• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "shader_manager.h"
17 
18 #include <algorithm>
19 #include <cinttypes>
20 #include <cstring>
21 
22 #include <base/containers/array_view.h>
23 #include <core/io/intf_file_manager.h>
24 #include <render/device/gpu_resource_desc.h>
25 #include <render/device/pipeline_layout_desc.h>
26 #include <render/namespace.h>
27 
28 #include "device/device.h"
29 #include "device/gpu_program.h"
30 #include "device/gpu_program_util.h"
31 #include "device/gpu_resource_handle_util.h"
32 #include "device/shader_module.h"
33 #include "device/shader_pipeline_binder.h"
34 #include "loader/shader_loader.h"
35 #include "resource_handle_impl.h"
36 #include "util/log.h"
37 
38 using namespace BASE_NS;
39 using namespace CORE_NS;
40 
// Bit offset of the primitive topology inside the InputAssembly hash (bit 0 holds primitive restart).
constexpr uint64_t IA_HASH_PRIMITIVE_TOPOLOGY_SHIFT { 1 };

// Bit offsets of the enum fields inside the RasterizationState hash (bits 0-2 hold the boolean toggles).
constexpr uint64_t RS_HASH_POLYGON_MODE_SHIFT { 4 };
constexpr uint64_t RS_HASH_CULL_MODE_SHIFT { 8 };
constexpr uint64_t RS_HASH_FRONT_FACE_SHIFT { 12 };

// Bit offset of the depth compare op inside the DepthStencilState hash (bits 0-3 hold the enable flags).
constexpr uint64_t DSS_HASH_DEPTH_COMPARE_SHIFT { 4 };

// Where each per-section hash is placed inside the packed GraphicsState hash.
constexpr uint64_t HASH_RS_SHIFT { 0 };
constexpr uint64_t HASH_DS_SHIFT { 32 };
constexpr uint64_t HASH_IA_SHIFT { 56 };
52 
// A float overlaid with a 32-bit unsigned integer of the same storage.
// NOTE(review): writing `f` and then reading `ui` is type punning through an inactive union member, which ISO C++
// leaves undefined; std::memcpy (or std::bit_cast in C++20) is the portable way to read a float's bits.
union FloatAsUint32 { float f; uint32_t ui; };
57 
58 template<>
hash(const RENDER_NS::GraphicsState::InputAssembly & inputAssembly)59 uint64_t BASE_NS::hash(const RENDER_NS::GraphicsState::InputAssembly& inputAssembly)
60 {
61     uint64_t hash = 0;
62     hash |= static_cast<uint64_t>(inputAssembly.enablePrimitiveRestart);
63     hash |= (static_cast<uint64_t>(inputAssembly.primitiveTopology) << IA_HASH_PRIMITIVE_TOPOLOGY_SHIFT);
64     return hash;
65 }
66 
67 template<>
hash(const RENDER_NS::GraphicsState::RasterizationState & state)68 uint64_t BASE_NS::hash(const RENDER_NS::GraphicsState::RasterizationState& state)
69 {
70     uint64_t hash = 0;
71     hash |= (static_cast<uint64_t>(state.enableRasterizerDiscard) << 2u) |
72             (static_cast<uint64_t>(state.enableDepthBias) << 1u) | static_cast<uint64_t>(state.enableDepthClamp);
73     hash |= (static_cast<uint64_t>(state.polygonMode) << RS_HASH_POLYGON_MODE_SHIFT);
74     hash |= (static_cast<uint64_t>(state.cullModeFlags) << RS_HASH_CULL_MODE_SHIFT);
75     hash |= (static_cast<uint64_t>(state.frontFace) << RS_HASH_FRONT_FACE_SHIFT);
76     return hash;
77 }
78 
79 template<>
hash(const RENDER_NS::GraphicsState::DepthStencilState & state)80 uint64_t BASE_NS::hash(const RENDER_NS::GraphicsState::DepthStencilState& state)
81 {
82     uint64_t hash = 0;
83     hash |= (static_cast<uint64_t>(state.enableStencilTest) << 3u) |
84             (static_cast<uint64_t>(state.enableDepthBoundsTest) << 2u) |
85             (static_cast<uint64_t>(state.enableDepthWrite) << 1u) | static_cast<uint64_t>(state.enableDepthTest);
86     hash |= (static_cast<uint64_t>(state.depthCompareOp) << DSS_HASH_DEPTH_COMPARE_SHIFT);
87     return hash;
88 }
89 
90 template<>
hash(const RENDER_NS::GraphicsState::ColorBlendState::Attachment & state)91 uint64_t BASE_NS::hash(const RENDER_NS::GraphicsState::ColorBlendState::Attachment& state)
92 {
93     uint64_t hash = 0;
94     hash |= (static_cast<uint64_t>(state.enableBlend) << 0u);
95     // blend factor values 0 - 18, 0x1f for exact (5 bits)
96     hash |= (static_cast<uint64_t>(state.srcColorBlendFactor) << 1u);
97     hash |= ((static_cast<uint64_t>(state.dstColorBlendFactor) & 0x1f) << 6u);
98     hash |= ((static_cast<uint64_t>(state.srcAlphaBlendFactor) & 0x1f) << 12u);
99     hash |= ((static_cast<uint64_t>(state.dstAlphaBlendFactor) & 0x1f) << 18u);
100     // blend op values 0 - 4, 0x7 for exact (3 bits)
101     hash |= ((static_cast<uint64_t>(state.colorBlendOp) & 0x7) << 24u);
102     hash |= ((static_cast<uint64_t>(state.alphaBlendOp) & 0x7) << 28u);
103     return hash;
104 }
105 
106 template<>
hash(const RENDER_NS::GraphicsState::ColorBlendState & state)107 uint64_t BASE_NS::hash(const RENDER_NS::GraphicsState::ColorBlendState& state)
108 {
109     uint64_t hash = 0;
110     hash |= (static_cast<uint64_t>(state.enableLogicOp) << 0u);
111     hash |= (static_cast<uint64_t>(state.logicOp) << 1u);
112 
113     FloatAsUint32 vec[4u] = { { state.colorBlendConstants[0u] }, { state.colorBlendConstants[1u] },
114         { state.colorBlendConstants[2u] }, { state.colorBlendConstants[3u] } };
115     const uint64_t hashRG = (static_cast<uint64_t>(vec[0u].ui) << 32) | (vec[1u].ui);
116     const uint64_t hashBA = (static_cast<uint64_t>(vec[2u].ui) << 32) | (vec[3u].ui);
117     HashCombine(hash, hashRG, hashBA);
118     for (uint32_t idx = 0; idx < state.colorAttachmentCount; ++idx) {
119         HashCombine(hash, state.colorAttachments[idx]);
120     }
121     return hash;
122 }
123 
124 template<>
hash(const RENDER_NS::GraphicsState & state)125 uint64_t BASE_NS::hash(const RENDER_NS::GraphicsState& state)
126 {
127     const uint64_t iaHash = hash(state.inputAssembly);
128     const uint64_t rsHash = hash(state.rasterizationState);
129     const uint64_t dsHash = hash(state.depthStencilState);
130     const uint64_t dynHash = state.dynamicStateFlags;
131     const uint64_t cbsHash = hash(state.colorBlendState);
132     uint64_t finalHash = (iaHash << HASH_IA_SHIFT) | (rsHash << HASH_RS_SHIFT) | (dsHash << HASH_DS_SHIFT);
133     HashCombine(finalHash, dynHash);
134     HashCombine(finalHash, cbsHash);
135     return finalHash;
136 }
137 
138 RENDER_BEGIN_NAMESPACE()
139 namespace {
IsUniformBuffer(const DescriptorType descriptorType)140 constexpr inline bool IsUniformBuffer(const DescriptorType descriptorType)
141 {
142     return ((descriptorType == CORE_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC) ||
143             (descriptorType == CORE_DESCRIPTOR_TYPE_UNIFORM_BUFFER));
144 }
IsStorageBuffer(const DescriptorType descriptorType)145 constexpr inline bool IsStorageBuffer(const DescriptorType descriptorType)
146 {
147     return ((descriptorType == CORE_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC) ||
148             (descriptorType == CORE_DESCRIPTOR_TYPE_STORAGE_BUFFER));
149 }
150 
GetPipelineLayoutCompatibilityFlags(const PipelineLayout & lhs,const PipelineLayout & rhs)151 ShaderManager::CompatibilityFlags GetPipelineLayoutCompatibilityFlags(
152     const PipelineLayout& lhs, const PipelineLayout& rhs)
153 {
154     ShaderManager::CompatibilityFlags flags = ShaderManager::CompatibilityFlagBits::COMPATIBLE_BIT;
155     for (uint32_t setIdx = 0; setIdx < PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT; ++setIdx) {
156         const auto& lSet = lhs.descriptorSetLayouts[setIdx];
157         const auto& rSet = rhs.descriptorSetLayouts[setIdx];
158         if (lSet.set == rSet.set) {
159             for (uint32_t lIdx = 0; lIdx < lSet.bindings.size(); ++lIdx) {
160                 const auto& lBind = lSet.bindings[lIdx];
161                 for (uint32_t rIdx = 0; rIdx < rSet.bindings.size(); ++rIdx) {
162                     const auto& rBind = rSet.bindings[rIdx];
163                     if (lBind.binding == rBind.binding) {
164                         if ((lBind.descriptorCount != rBind.descriptorCount) ||
165                             (lBind.descriptorType != rBind.descriptorType)) {
166                             // re-check dynamic offsets
167                             if ((IsUniformBuffer(lBind.descriptorType) != IsUniformBuffer(rBind.descriptorType)) &&
168                                 (IsStorageBuffer(lBind.descriptorType) != IsStorageBuffer(rBind.descriptorType))) {
169                                 flags = 0;
170                             }
171                         }
172                     }
173                 }
174             }
175         }
176     }
177     if (flags != 0) {
178         // check for exact match
179         bool isExact = true;
180         for (uint32_t setIdx = 0; setIdx < PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT; ++setIdx) {
181             const auto& lSet = lhs.descriptorSetLayouts[setIdx];
182             const auto& rSet = rhs.descriptorSetLayouts[setIdx];
183             if (lSet.set == rSet.set) {
184                 if (lSet.bindings.size() == rSet.bindings.size()) {
185                     for (size_t idx = 0; idx < lSet.bindings.size(); ++idx) {
186                         const int cmpRes =
187                             std::memcmp(&(lSet.bindings[idx]), &(rSet.bindings[idx]), sizeof(lSet.bindings[idx]));
188                         if (cmpRes != 0) {
189                             isExact = false;
190                             break;
191                         }
192                     }
193                 } else {
194                     isExact = false;
195                     break;
196                 }
197             }
198         }
199         if (isExact) {
200             flags |= ShaderManager::CompatibilityFlagBits::EXACT_BIT;
201         }
202     }
203     return flags;
204 }
205 
206 // NOTE: checking the type for validity is enough
IsComputeShaderFunc(RenderHandle handle)207 inline bool IsComputeShaderFunc(RenderHandle handle)
208 {
209     return RenderHandleType::COMPUTE_SHADER_STATE_OBJECT == RenderHandleUtil::GetHandleType(handle);
210 }
211 
IsShaderFunc(RenderHandle handle)212 inline bool IsShaderFunc(RenderHandle handle)
213 {
214     return RenderHandleType::SHADER_STATE_OBJECT == RenderHandleUtil::GetHandleType(handle);
215 }
216 
IsAnyShaderFunc(RenderHandle handle)217 inline bool IsAnyShaderFunc(RenderHandle handle)
218 {
219     return (RenderHandleType::COMPUTE_SHADER_STATE_OBJECT == RenderHandleUtil::GetHandleType(handle)) ||
220            (RenderHandleType::SHADER_STATE_OBJECT == RenderHandleUtil::GetHandleType(handle));
221 }
222 
GetShadersBySlot(const uint32_t renderSlotId,const ShaderManager::ComputeMappings & mappings,vector<RenderHandleReference> & shaders)223 inline void GetShadersBySlot(
224     const uint32_t renderSlotId, const ShaderManager::ComputeMappings& mappings, vector<RenderHandleReference>& shaders)
225 {
226     for (const auto& ref : mappings.clientData) {
227         if (ref.renderSlotId == renderSlotId) {
228             shaders.emplace_back(ref.rhr);
229         }
230     }
231 }
232 
GetShadersBySlot(const uint32_t renderSlotId,const ShaderManager::GraphicsMappings & mappings,vector<RenderHandleReference> & shaders)233 inline void GetShadersBySlot(const uint32_t renderSlotId, const ShaderManager::GraphicsMappings& mappings,
234     vector<RenderHandleReference>& shaders)
235 {
236     for (const auto& ref : mappings.clientData) {
237         if (ref.renderSlotId == renderSlotId) {
238             shaders.emplace_back(ref.rhr);
239         }
240     }
241 }
242 
GetShadersBySlot(const uint32_t renderSlotId,const ShaderManager::ComputeMappings & mappings,vector<RenderHandle> & shaders)243 inline void GetShadersBySlot(
244     const uint32_t renderSlotId, const ShaderManager::ComputeMappings& mappings, vector<RenderHandle>& shaders)
245 {
246     for (const auto& ref : mappings.clientData) {
247         if (ref.renderSlotId == renderSlotId) {
248             shaders.emplace_back(ref.rhr.GetHandle());
249         }
250     }
251 }
252 
GetShadersBySlot(const uint32_t renderSlotId,const ShaderManager::GraphicsMappings & mappings,vector<RenderHandle> & shaders)253 inline void GetShadersBySlot(
254     const uint32_t renderSlotId, const ShaderManager::GraphicsMappings& mappings, vector<RenderHandle>& shaders)
255 {
256     for (const auto& ref : mappings.clientData) {
257         if (ref.renderSlotId == renderSlotId) {
258             shaders.emplace_back(ref.rhr.GetHandle());
259         }
260     }
261 }
262 
GetHandle(const string_view name,const unordered_map<string,RenderHandle> & nameToClientHandle)263 inline RenderHandle GetHandle(const string_view name, const unordered_map<string, RenderHandle>& nameToClientHandle)
264 {
265     if (auto const pos = nameToClientHandle.find(name); pos != nameToClientHandle.end()) {
266         return pos->second;
267     }
268     return {};
269 }
270 
HashHandleAndSlot(const RenderHandle & handle,const uint32_t renderSlotId)271 constexpr inline uint64_t HashHandleAndSlot(const RenderHandle& handle, const uint32_t renderSlotId)
272 {
273     // normally there are < 16 render slot ids used which way less than 0xffff
274     PLUGIN_ASSERT(renderSlotId < 0xffff);
275     return (handle.id << 16ull) | (renderSlotId & 0xffff);
276 }
277 
GetBaseGraphicsStateVariantIndex(const ShaderManager::GraphicsStateData & graphicsStates,const ShaderManager::GraphicsStateVariantCreateInfo & vci)278 uint32_t GetBaseGraphicsStateVariantIndex(
279     const ShaderManager::GraphicsStateData& graphicsStates, const ShaderManager::GraphicsStateVariantCreateInfo& vci)
280 {
281     uint32_t baseVariantIndex = INVALID_SM_INDEX;
282     if (!vci.baseShaderState.empty()) {
283         const string fullBaseName = vci.baseShaderState + vci.baseVariant;
284         if (const auto bhIter = graphicsStates.nameToIndex.find(fullBaseName);
285             bhIter != graphicsStates.nameToIndex.cend()) {
286             PLUGIN_ASSERT(bhIter->second < graphicsStates.rhr.size());
287             if ((bhIter->second < graphicsStates.rhr.size()) && graphicsStates.rhr[bhIter->second]) {
288                 const RenderHandle baseHandle = graphicsStates.rhr[bhIter->second].GetHandle();
289                 baseVariantIndex = RenderHandleUtil::GetIndexPart(baseHandle);
290             }
291         } else {
292             PLUGIN_LOG_W("base state not found (%s %s)", vci.baseShaderState.data(), vci.baseVariant.data());
293         }
294     }
295     return baseVariantIndex;
296 }
297 } // namespace
298 
ShaderManager(Device & device)299 ShaderManager::ShaderManager(Device& device) : device_(device) {}
300 
301 ShaderManager::~ShaderManager() = default;
302 
HashGraphicsState(const GraphicsState & graphicsState) const303 uint64_t ShaderManager::HashGraphicsState(const GraphicsState& graphicsState) const
304 {
305     return BASE_NS::hash(graphicsState);
306 }
307 
CreateRenderSlotId(const string_view renderSlot)308 uint32_t ShaderManager::CreateRenderSlotId(const string_view renderSlot)
309 {
310     if (renderSlot.empty()) {
311         return INVALID_SM_INDEX;
312     }
313 
314     if (const auto iter = renderSlotIds_.nameToId.find(renderSlot); iter != renderSlotIds_.nameToId.cend()) {
315         return iter->second;
316     } else { // create new id
317         const uint32_t renderSlotId = static_cast<uint32_t>(renderSlotIds_.data.size());
318         renderSlotIds_.nameToId[renderSlot] = renderSlotId;
319         renderSlotIds_.data.push_back(RenderSlotData { renderSlotId, {}, {} });
320         return renderSlotId;
321     }
322 }
323 
SetRenderSlotData(const uint32_t renderSlotId,const RenderHandleReference & shaderHandle,const RenderHandleReference & stateHandle)324 void ShaderManager::SetRenderSlotData(
325     const uint32_t renderSlotId, const RenderHandleReference& shaderHandle, const RenderHandleReference& stateHandle)
326 {
327     if (renderSlotId < static_cast<uint32_t>(renderSlotIds_.data.size())) {
328         if (IsAnyShaderFunc(shaderHandle.GetHandle())) {
329             renderSlotIds_.data[renderSlotId].shader = shaderHandle;
330         }
331         if (RenderHandleUtil::GetHandleType(stateHandle.GetHandle()) == RenderHandleType::GRAPHICS_STATE) {
332             renderSlotIds_.data[renderSlotId].graphicsState = stateHandle;
333         }
334     }
335 }
336 
CreateClientData(const string_view name,const RenderHandleType type,const ClientDataIndices & cdi)337 RenderHandle ShaderManager::CreateClientData(
338     const string_view name, const RenderHandleType type, const ClientDataIndices& cdi)
339 {
340     RenderHandle clientHandle;
341     PLUGIN_ASSERT(
342         (type == RenderHandleType::COMPUTE_SHADER_STATE_OBJECT) || (type == RenderHandleType::SHADER_STATE_OBJECT));
343     if (const auto iter = nameToClientHandle_.find(name); iter != nameToClientHandle_.end()) {
344         clientHandle = iter->second;
345     } else {
346         const uint32_t arrayIndex = (type == RenderHandleType::COMPUTE_SHADER_STATE_OBJECT)
347                                         ? static_cast<uint32_t>(computeShaderMappings_.clientData.size())
348                                         : static_cast<uint32_t>(shaderMappings_.clientData.size());
349         clientHandle = RenderHandleUtil::CreateGpuResourceHandle(type, 0, arrayIndex, 0);
350         RenderHandleReference rhr =
351             RenderHandleReference(clientHandle, IRenderReferenceCounter::Ptr(new ShaderReferenceCounter()));
352         if (type == RenderHandleType::COMPUTE_SHADER_STATE_OBJECT) {
353             computeShaderMappings_.clientData.push_back(
354                 { move(rhr), {}, cdi.renderSlotIndex, cdi.pipelineLayoutIndex, cdi.reflectionPipelineLayoutIndex });
355         } else {
356             shaderMappings_.clientData.push_back(
357                 { move(rhr), {}, cdi.renderSlotIndex, cdi.pipelineLayoutIndex, cdi.reflectionPipelineLayoutIndex });
358         }
359         if (!name.empty()) {
360             nameToClientHandle_[name] = clientHandle;
361         }
362     }
363 
364     return clientHandle;
365 }
366 
Create(const ComputeShaderCreateData & createInfo,const string_view baseShaderPath,const string_view variantName)367 RenderHandleReference ShaderManager::Create(
368     const ComputeShaderCreateData& createInfo, const string_view baseShaderPath, const string_view variantName)
369 {
370     const string fullName = createInfo.path + variantName;
371 
372     // reflection pipeline layout
373     uint32_t reflectionPlIndex = INVALID_SM_INDEX;
374     if (const ShaderModule* cs = GetShaderModule(createInfo.shaderModuleIndex); cs) {
375         const RenderHandleReference plRhr = CreatePipelineLayout({ fullName, cs->GetPipelineLayout() });
376         reflectionPlIndex = RenderHandleUtil::GetIndexPart(plRhr.GetHandle());
377     }
378 
379     auto const clientHandle = CreateClientData(fullName, RenderHandleType::COMPUTE_SHADER_STATE_OBJECT,
380         { createInfo.renderSlotId, createInfo.pipelineLayoutIndex, reflectionPlIndex });
381     if (createInfo.pipelineLayoutIndex != INVALID_SM_INDEX) {
382         pl_.computeShaderToIndex[clientHandle] = createInfo.pipelineLayoutIndex;
383     }
384 
385     {
386         const auto lock = std::lock_guard(pendingMutex_);
387         pendingAllocations_.computeShaders.push_back(
388             { clientHandle, createInfo.shaderModuleIndex, createInfo.pipelineLayoutIndex });
389     }
390 
391     const uint32_t index = RenderHandleUtil::GetIndexPart(clientHandle);
392     if (IsComputeShaderFunc(clientHandle) &&
393         (index < static_cast<uint32_t>(computeShaderMappings_.clientData.size()))) {
394         auto& clientDataRef = computeShaderMappings_.clientData[index];
395         // add base shader if given
396         if (!baseShaderPath.empty()) {
397             if (const auto baseHandleIter = nameToClientHandle_.find(baseShaderPath);
398                 baseHandleIter != nameToClientHandle_.cend()) {
399                 if (RenderHandleUtil::IsValid(baseHandleIter->second)) {
400                     clientDataRef.baseShaderHandle = baseHandleIter->second;
401                     const uint64_t hash = HashHandleAndSlot(clientDataRef.baseShaderHandle, createInfo.renderSlotId);
402                     hashToShaderVariant_[hash] = clientHandle;
403                 }
404             } else {
405                 PLUGIN_LOG_W("base shader (%s) not found for (%s)", baseShaderPath.data(), createInfo.path.data());
406             }
407         }
408         return clientDataRef.rhr;
409     } else {
410         return {};
411     }
412 }
413 
Create(const ShaderCreateData & createInfo,const string_view baseShaderPath,const string_view variantName)414 RenderHandleReference ShaderManager::Create(
415     const ShaderCreateData& createInfo, const string_view baseShaderPath, const string_view variantName)
416 {
417     const string fullName = createInfo.path + variantName;
418 
419     // reflection pipeline layout
420     uint32_t reflectionPlIndex = INVALID_SM_INDEX;
421     {
422         const ShaderModule* vs = GetShaderModule(createInfo.vertShaderModuleIndex);
423         const ShaderModule* fs = GetShaderModule(createInfo.fragShaderModuleIndex);
424         if (vs && fs) {
425             const PipelineLayout layouts[] { vs->GetPipelineLayout(), fs->GetPipelineLayout() };
426             PipelineLayout pl;
427             GpuProgramUtil::CombinePipelineLayouts({ layouts, 2u }, pl);
428             const RenderHandleReference plRhr = CreatePipelineLayout({ fullName, pl });
429             reflectionPlIndex = RenderHandleUtil::GetIndexPart(plRhr.GetHandle());
430         }
431     }
432 
433     auto const clientHandle = CreateClientData(fullName, RenderHandleType::SHADER_STATE_OBJECT,
434         { createInfo.renderSlotId, createInfo.pipelineLayoutIndex, reflectionPlIndex });
435 
436     if (createInfo.pipelineLayoutIndex != INVALID_SM_INDEX) {
437         pl_.shaderToIndex[clientHandle] = createInfo.pipelineLayoutIndex;
438     }
439     if (createInfo.vertexInputDeclarationIndex != INVALID_SM_INDEX) {
440         shaderVid_.shaderToIndex[clientHandle] = createInfo.vertexInputDeclarationIndex;
441     }
442 
443     {
444         const auto lock = std::lock_guard(pendingMutex_);
445         pendingAllocations_.shaders.push_back({ clientHandle, createInfo.vertShaderModuleIndex,
446             createInfo.fragShaderModuleIndex, createInfo.pipelineLayoutIndex, createInfo.vertexInputDeclarationIndex });
447     }
448 
449     if (!createInfo.materialMetadata.empty()) {
450         MaterialMetadata metadata { string(createInfo.materialMetadata), json::value {} };
451         metadata.json = json::parse(metadata.raw.data());
452         if (metadata.json) {
453             shaderToMetadata_.insert({ clientHandle, move(metadata) });
454         }
455     }
456 
457     const uint32_t index = RenderHandleUtil::GetIndexPart(clientHandle);
458     if (IsShaderFunc(clientHandle) && (index < static_cast<uint32_t>(shaderMappings_.clientData.size()))) {
459         auto& clientDataRef = shaderMappings_.clientData[index];
460         clientDataRef.graphicsStateIndex = createInfo.graphicsStateIndex;
461         clientDataRef.vertexInputDeclarationIndex = createInfo.vertexInputDeclarationIndex;
462         // add base shader if given
463 #if (RENDER_VALIDATION_ENABLED == 1)
464         if ((!variantName.empty()) && baseShaderPath.empty()) {
465             PLUGIN_LOG_W("RENDER_VALIDATION: base shader path not give to variant (%s %s)", createInfo.path.data(),
466                 variantName.data());
467         }
468 #endif
469         if (!baseShaderPath.empty()) {
470             if (const auto baseHandleIter = nameToClientHandle_.find(baseShaderPath);
471                 baseHandleIter != nameToClientHandle_.cend()) {
472                 if (RenderHandleUtil::IsValid(baseHandleIter->second)) {
473                     clientDataRef.baseShaderHandle = baseHandleIter->second;
474                     const uint64_t hash = HashHandleAndSlot(clientDataRef.baseShaderHandle, createInfo.renderSlotId);
475                     hashToShaderVariant_[hash] = clientHandle;
476                 }
477             } else {
478                 PLUGIN_LOG_W("base shader (%s) not found for (%s)", baseShaderPath.data(), createInfo.path.data());
479             }
480         }
481         return clientDataRef.rhr;
482     } else {
483         return {};
484     }
485 }
486 
AddAdditionalNameForHandle(const RenderHandleReference & handle,const string_view name)487 void ShaderManager::AddAdditionalNameForHandle(const RenderHandleReference& handle, const string_view name)
488 {
489     if (handle) {
490         const RenderHandle rawHandle = handle.GetHandle();
491         const RenderHandleType handleType = RenderHandleUtil::GetHandleType(rawHandle);
492         // add name only if name not used yet
493         if ((handleType == RenderHandleType::COMPUTE_SHADER_STATE_OBJECT) ||
494             (handleType == RenderHandleType::SHADER_STATE_OBJECT)) {
495             if (!nameToClientHandle_.contains(name)) {
496                 nameToClientHandle_[name] = rawHandle;
497             } else {
498                 PLUGIN_LOG_W("trying to add additional name (%s) for shader handle, but the name is already in use",
499                     name.data());
500             }
501         }
502     }
503 }
504 
CreateComputeShader(const ComputeShaderCreateInfo & createInfo,const string_view baseShaderPath,const string_view variantName)505 RenderHandleReference ShaderManager::CreateComputeShader(
506     const ComputeShaderCreateInfo& createInfo, const string_view baseShaderPath, const string_view variantName)
507 {
508     if (createInfo.shaderPaths.size() >= 1u) {
509         if (const uint32_t moduleIdx = GetShaderModuleIndex(createInfo.shaderPaths[0].path);
510             moduleIdx != INVALID_SM_INDEX) {
511             return Create(ComputeShaderCreateData { createInfo.path, createInfo.renderSlotId,
512                               RenderHandleUtil::GetIndexPart(createInfo.pipelineLayout), moduleIdx },
513                 baseShaderPath, variantName);
514         } else {
515             PLUGIN_LOG_E("ShaderManager: compute shader (%s) creation failed, compute shader path (%s) not found",
516                 string(createInfo.path).c_str(), string(createInfo.shaderPaths[0].path).c_str());
517         }
518     } else {
519         PLUGIN_LOG_E("ShaderManager: compute shader (%s) creation failed, no shader module paths given",
520             string(createInfo.path).c_str());
521     }
522     return {};
523 }
524 
CreateComputeShader(const ComputeShaderCreateInfo & createInfo)525 RenderHandleReference ShaderManager::CreateComputeShader(const ComputeShaderCreateInfo& createInfo)
526 {
527     return CreateComputeShader(createInfo, "", "");
528 }
529 
CreateShader(const ShaderCreateInfo & createInfo,const string_view baseShaderPath,const string_view variantName)530 RenderHandleReference ShaderManager::CreateShader(
531     const ShaderCreateInfo& createInfo, const string_view baseShaderPath, const string_view variantName)
532 {
533     if (createInfo.shaderPaths.size() >= 2u) {
534         const uint32_t vertShaderModule = GetShaderModuleIndex(createInfo.shaderPaths[0u].path);
535         const uint32_t fragShaderModule = GetShaderModuleIndex(createInfo.shaderPaths[1u].path);
536         if ((vertShaderModule != INVALID_SM_INDEX) && (fragShaderModule != INVALID_SM_INDEX)) {
537             return Create(
538                 ShaderCreateData { createInfo.path, createInfo.renderSlotId,
539                     RenderHandleUtil::GetIndexPart(createInfo.vertexInputDeclaration),
540                     RenderHandleUtil::GetIndexPart(createInfo.pipelineLayout),
541                     RenderHandleUtil::GetIndexPart(createInfo.graphicsState), vertShaderModule, fragShaderModule, {} },
542                 baseShaderPath, variantName);
543         } else {
544             PLUGIN_LOG_E("ShaderManager: shader (%s) creation failed, shader path (vert:%s) (frag:%s) not found",
545                 string(createInfo.path).c_str(), string(createInfo.shaderPaths[0u].path).c_str(),
546                 string(createInfo.shaderPaths[1u].path).c_str());
547         }
548     } else {
549         PLUGIN_LOG_E("ShaderManager: shader (%s) creation failed, no shader module paths given",
550             string(createInfo.path).c_str());
551     }
552     return {};
553 }
554 
CreateShader(const ShaderCreateInfo & createInfo)555 RenderHandleReference ShaderManager::CreateShader(const ShaderCreateInfo& createInfo)
556 {
557     return CreateShader(createInfo, "", "");
558 }
559 
HandlePendingAllocations()560 void ShaderManager::HandlePendingAllocations()
561 {
562     pendingMutex_.lock();
563     decltype(pendingAllocations_) pendingAllocations = move(pendingAllocations_);
564     pendingMutex_.unlock();
565 
566     for (const auto& handleRef : pendingAllocations.destroyHandles) {
567         const RenderHandleType handleType = RenderHandleUtil::GetHandleType(handleRef);
568         const uint32_t arrayIndex = RenderHandleUtil::GetIndexPart(handleRef);
569         if (handleType == RenderHandleType::COMPUTE_SHADER_STATE_OBJECT) {
570             if (arrayIndex < static_cast<uint32_t>(computeShaders_.size())) {
571                 computeShaders_[arrayIndex] = {};
572             }
573         } else if (handleType == RenderHandleType::SHADER_STATE_OBJECT) {
574             if (arrayIndex < static_cast<uint32_t>(shaders_.size())) {
575                 shaders_[arrayIndex] = {};
576             }
577         }
578     }
579     HandlePendingShaders(pendingAllocations);
580     HandlePendingModules(pendingAllocations);
581 
582     const uint64_t frameCount = device_.GetFrameCount();
583     constexpr uint64_t additionalFrameCount { 2u };
584     const auto minAge = device_.GetCommandBufferingCount() + additionalFrameCount;
585     const auto ageLimit = (frameCount < minAge) ? 0 : (frameCount - minAge);
586     auto CompareForErase = [](const auto ageLimit, auto& vec) {
587         for (auto iter = vec.begin(); iter != vec.end();) {
588             if (iter->frameIndex < ageLimit) {
589                 iter = vec.erase(iter);
590             } else {
591                 ++iter;
592             }
593         }
594     };
595     CompareForErase(ageLimit, deferredDestructions_.shaderModules);
596     CompareForErase(ageLimit, deferredDestructions_.computePrograms);
597     CompareForErase(ageLimit, deferredDestructions_.shaderPrograms);
598 
599     hasReloadedShaders_ = false;
600 }
601 
HandlePendingShaders(Allocs & allocs)602 void ShaderManager::HandlePendingShaders(Allocs& allocs)
603 {
604     const uint64_t frameCount = device_.GetFrameCount();
605     for (const auto& ref : allocs.computeShaders) {
606         const uint32_t arrayIndex = RenderHandleUtil::GetIndexPart(ref.handle);
607         ShaderModule* shaderModule = GetShaderModule(ref.computeModuleIndex);
608         if (shaderModule) {
609             if (arrayIndex < static_cast<uint32_t>(computeShaders_.size())) {
610                 // replace with new (push old for deferred destruction)
611                 deferredDestructions_.computePrograms.push_back({ frameCount, move(computeShaders_[arrayIndex].gsp) });
612                 computeShaders_[arrayIndex] = { device_.CreateGpuComputeProgram({ shaderModule }),
613                     ref.pipelineLayoutIndex, ref.computeModuleIndex };
614             } else {
615                 // new gpu resource
616                 computeShaders_.push_back({ device_.CreateGpuComputeProgram({ shaderModule }), ref.pipelineLayoutIndex,
617                     ref.computeModuleIndex });
618             }
619         }
620 #if (RENDER_VALIDATION_ENABLED == 1)
621         if (!shaderModule) {
622             PLUGIN_LOG_E("RENDER_VALIDATION: Compute shader module with index:%u, not found", ref.computeModuleIndex);
623         }
624 #endif
625     }
626     for (const auto& ref : allocs.shaders) {
627         uint32_t arrayIndex = RenderHandleUtil::GetIndexPart(ref.handle);
628         ShaderModule* vertShaderModule = GetShaderModule(ref.vertModuleIndex);
629         ShaderModule* fragShaderModule = GetShaderModule(ref.fragModuleIndex);
630         if (vertShaderModule && fragShaderModule) {
631             if ((arrayIndex < static_cast<uint32_t>(shaders_.size()))) {
632                 // replace with new (push old for deferred destruction)
633                 deferredDestructions_.shaderPrograms.push_back({ frameCount, move(shaders_[arrayIndex].gsp) });
634                 shaders_[arrayIndex] = { device_.CreateGpuShaderProgram({ vertShaderModule, fragShaderModule }),
635                     ref.pipelineLayoutIndex, ref.vertexInputDeclIndex, ref.vertModuleIndex, ref.fragModuleIndex };
636             } else { // new gpu resource
637                 shaders_.push_back({ device_.CreateGpuShaderProgram({ vertShaderModule, fragShaderModule }),
638                     ref.pipelineLayoutIndex, ref.vertexInputDeclIndex, ref.vertModuleIndex, ref.fragModuleIndex });
639             }
640         }
641 #if (RENDER_VALIDATION_ENABLED == 1)
642         if ((!vertShaderModule) || (!fragShaderModule)) {
643             PLUGIN_LOG_E("RENDER_VALIDATION: Shader module with index: %u or %u, not found", ref.vertModuleIndex,
644                 ref.fragModuleIndex);
645         }
646 #endif
647     }
648 }
649 
HandlePendingModules(Allocs & allocs)650 void ShaderManager::HandlePendingModules(Allocs& allocs)
651 {
652     const uint64_t frameCount = device_.GetFrameCount();
653     for (const auto modIdx : allocs.recreatedComputeModuleIndices) {
654         for (auto& shaderRef : computeShaders_) {
655             if (modIdx == shaderRef.compModuleIndex) {
656                 if (ShaderModule* compModule = GetShaderModule(shaderRef.compModuleIndex); compModule) {
657                     deferredDestructions_.computePrograms.push_back({ frameCount, move(shaderRef.gsp) });
658                     shaderRef.gsp = device_.CreateGpuComputeProgram({ compModule });
659                 }
660             }
661         }
662     }
663     for (const auto modIdx : allocs.recreatedShaderModuleIndices) {
664         for (auto& shaderRef : shaders_) {
665             if ((modIdx == shaderRef.vertModuleIndex) || (modIdx == shaderRef.fragModuleIndex)) {
666                 ShaderModule* vertModule = GetShaderModule(shaderRef.vertModuleIndex);
667                 ShaderModule* fragModule = GetShaderModule(shaderRef.fragModuleIndex);
668                 if (vertModule && fragModule) {
669                     deferredDestructions_.shaderPrograms.push_back({ frameCount, move(shaderRef.gsp) });
670                     shaderRef.gsp = device_.CreateGpuShaderProgram({ vertModule, fragModule });
671                 }
672             }
673         }
674     }
675 }
676 
GetShaderHandle(const string_view name) const677 RenderHandleReference ShaderManager::GetShaderHandle(const string_view name) const
678 {
679     const RenderHandle handle = GetHandle(name, nameToClientHandle_);
680     const RenderHandleType handleType = RenderHandleUtil::GetHandleType(handle);
681     const uint32_t index = RenderHandleUtil::GetIndexPart(handle);
682     if ((handleType == RenderHandleType::COMPUTE_SHADER_STATE_OBJECT) &&
683         (index < static_cast<uint32_t>(computeShaderMappings_.clientData.size()))) {
684         return computeShaderMappings_.clientData[index].rhr;
685     } else if ((handleType == RenderHandleType::SHADER_STATE_OBJECT) &&
686                (index < static_cast<uint32_t>(shaderMappings_.clientData.size()))) {
687         return shaderMappings_.clientData[index].rhr;
688     } else {
689         PLUGIN_LOG_W("ShaderManager: invalid shader %s", name.data());
690         return {};
691     }
692 }
693 
GetShaderHandle(const string_view name,const string_view variantName) const694 RenderHandleReference ShaderManager::GetShaderHandle(const string_view name, const string_view variantName) const
695 {
696     const string fullName = name + variantName;
697     const RenderHandle handle = GetHandle(name, nameToClientHandle_);
698     const RenderHandleType handleType = RenderHandleUtil::GetHandleType(handle);
699     const uint32_t index = RenderHandleUtil::GetIndexPart(handle);
700     if ((handleType == RenderHandleType::COMPUTE_SHADER_STATE_OBJECT) &&
701         (index < static_cast<uint32_t>(computeShaderMappings_.clientData.size()))) {
702         return computeShaderMappings_.clientData[index].rhr;
703     } else if ((handleType == RenderHandleType::SHADER_STATE_OBJECT) &&
704                (index < static_cast<uint32_t>(shaderMappings_.clientData.size()))) {
705         return shaderMappings_.clientData[index].rhr;
706     } else {
707         PLUGIN_LOG_W("ShaderManager: invalid shader (%s) variant (%s)", name.data(), variantName.data());
708         return {};
709     }
710 }
711 
GetShaderHandle(const RenderHandle & handle,const uint32_t renderSlotId) const712 RenderHandleReference ShaderManager::GetShaderHandle(const RenderHandle& handle, const uint32_t renderSlotId) const
713 {
714     const RenderHandleType handleType = RenderHandleUtil::GetHandleType(handle);
715     if ((handleType != RenderHandleType::COMPUTE_SHADER_STATE_OBJECT) &&
716         (handleType != RenderHandleType::SHADER_STATE_OBJECT)) {
717         return {}; // early out
718     }
719 
720     const uint32_t index = RenderHandleUtil::GetIndexPart(handle);
721     RenderHandle baseShaderHandle;
722     // check first for own validity and possible base shader handle
723     if ((handleType == RenderHandleType::COMPUTE_SHADER_STATE_OBJECT) &&
724         (index < static_cast<uint32_t>(computeShaderMappings_.clientData.size()))) {
725         const auto& ref = computeShaderMappings_.clientData[index];
726         if (ref.renderSlotId == renderSlotId) {
727             return ref.rhr;
728         }
729         baseShaderHandle = ref.baseShaderHandle;
730     } else if ((handleType == RenderHandleType::SHADER_STATE_OBJECT) &&
731                (index < static_cast<uint32_t>(shaderMappings_.clientData.size()))) {
732         const auto& ref = shaderMappings_.clientData[index];
733         if (ref.renderSlotId == renderSlotId) {
734             return ref.rhr;
735         }
736         baseShaderHandle = ref.baseShaderHandle;
737     }
738     // try to find a match through base shader variant
739     if (RenderHandleUtil::IsValid(baseShaderHandle)) {
740         const uint64_t hash = HashHandleAndSlot(baseShaderHandle, renderSlotId);
741         if (const auto iter = hashToShaderVariant_.find(hash); iter != hashToShaderVariant_.cend()) {
742             const RenderHandleType baseHandleType = RenderHandleUtil::GetHandleType(iter->second);
743             const uint32_t arrayIndex = RenderHandleUtil::GetIndexPart(iter->second);
744             if ((baseHandleType == RenderHandleType::COMPUTE_SHADER_STATE_OBJECT) &&
745                 (arrayIndex < computeShaderMappings_.clientData.size())) {
746                 PLUGIN_ASSERT(computeShaderMappings_.clientData[arrayIndex].renderSlotId == renderSlotId);
747                 return computeShaderMappings_.clientData[arrayIndex].rhr;
748             } else if ((baseHandleType == RenderHandleType::SHADER_STATE_OBJECT) &&
749                        (arrayIndex < shaderMappings_.clientData.size())) {
750                 PLUGIN_ASSERT(shaderMappings_.clientData[arrayIndex].renderSlotId == renderSlotId);
751                 return shaderMappings_.clientData[arrayIndex].rhr;
752             }
753         }
754     }
755     return {};
756 }
757 
// Reference-counted overload: forwards the raw handle to the RenderHandle overload.
RenderHandleReference ShaderManager::GetShaderHandle(
    const RenderHandleReference& handle, const uint32_t renderSlotId) const
{
    const RenderHandle rawHandle = handle.GetHandle();
    return GetShaderHandle(rawHandle, renderSlotId);
}
763 
GetShaders(const uint32_t renderSlotId) const764 vector<RenderHandleReference> ShaderManager::GetShaders(const uint32_t renderSlotId) const
765 {
766     vector<RenderHandleReference> shaders;
767     GetShadersBySlot(renderSlotId, shaderMappings_, shaders);
768     GetShadersBySlot(renderSlotId, computeShaderMappings_, shaders);
769     return shaders;
770 }
771 
GetShaderRawHandles(const uint32_t renderSlotId) const772 vector<RenderHandle> ShaderManager::GetShaderRawHandles(const uint32_t renderSlotId) const
773 {
774     vector<RenderHandle> shaders;
775     GetShadersBySlot(renderSlotId, shaderMappings_, shaders);
776     GetShadersBySlot(renderSlotId, computeShaderMappings_, shaders);
777     return shaders;
778 }
779 
CreateGraphicsState(const GraphicsStateCreateInfo & createInfo,const GraphicsStateVariantCreateInfo & variantCreateInfo)780 RenderHandleReference ShaderManager::CreateGraphicsState(
781     const GraphicsStateCreateInfo& createInfo, const GraphicsStateVariantCreateInfo& variantCreateInfo)
782 {
783     PLUGIN_ASSERT(graphicsStates_.rhr.size() == graphicsStates_.graphicsStates.size());
784     const uint32_t renderSlotId = CreateRenderSlotId(variantCreateInfo.renderSlot);
785     // NOTE: No collisions expected if path is used
786     const string fullName = createInfo.path + variantCreateInfo.variant;
787     uint32_t arrayIndex = INVALID_SM_INDEX;
788     if (auto nameIter = graphicsStates_.nameToIndex.find(fullName); nameIter != graphicsStates_.nameToIndex.end()) {
789         arrayIndex = static_cast<uint32_t>(nameIter->second);
790     }
791 
792     uint32_t baseVariantIndex = INVALID_SM_INDEX;
793     RenderHandleReference rhr;
794     if (arrayIndex < graphicsStates_.rhr.size()) {
795         rhr = graphicsStates_.rhr[arrayIndex];
796         graphicsStates_.graphicsStates[arrayIndex] = createInfo.graphicsState;
797         const uint64_t hash = HashGraphicsState(createInfo.graphicsState);
798         baseVariantIndex = GetBaseGraphicsStateVariantIndex(graphicsStates_, variantCreateInfo);
799         graphicsStates_.data[arrayIndex] = { hash, renderSlotId, baseVariantIndex };
800         graphicsStates_.hashToIndex[hash] = arrayIndex;
801     } else { // new
802         arrayIndex = static_cast<uint32_t>(graphicsStates_.rhr.size());
803         // NOTE: these are only updated for new states
804         if (!fullName.empty()) {
805             graphicsStates_.nameToIndex[fullName] = arrayIndex;
806         }
807         const RenderHandle handle = RenderHandleUtil::CreateHandle(RenderHandleType::GRAPHICS_STATE, arrayIndex);
808         graphicsStates_.rhr.emplace_back(
809             RenderHandleReference(handle, IRenderReferenceCounter::Ptr(new ShaderReferenceCounter())));
810         rhr = graphicsStates_.rhr[arrayIndex];
811         graphicsStates_.graphicsStates.emplace_back(createInfo.graphicsState);
812         const uint64_t hash = HashGraphicsState(createInfo.graphicsState);
813         // ordering matters, this fetches from nameToIndex
814         baseVariantIndex = GetBaseGraphicsStateVariantIndex(graphicsStates_, variantCreateInfo);
815         graphicsStates_.data.push_back({ hash, renderSlotId, baseVariantIndex });
816         graphicsStates_.hashToIndex[hash] = arrayIndex;
817     }
818     if (baseVariantIndex < graphicsStates_.rhr.size()) {
819         const uint64_t variantHash = HashHandleAndSlot(graphicsStates_.rhr[baseVariantIndex].GetHandle(), renderSlotId);
820         if (variantHash != INVALID_SM_INDEX) {
821 #if (RENDER_VALIDATION_ENABLED == 1)
822             if (graphicsStates_.variantHashToIndex.contains(variantHash)) {
823                 PLUGIN_LOG_W("RENDER_VALIDATION: overwriting variant hash with %s %s", createInfo.path.data(),
824                     variantCreateInfo.variant.data());
825             }
826 #endif
827             graphicsStates_.variantHashToIndex[variantHash] = RenderHandleUtil::GetIndexPart(rhr.GetHandle());
828         }
829     }
830 
831     return rhr;
832 }
833 
CreateGraphicsState(const GraphicsStateCreateInfo & createInfo)834 RenderHandleReference ShaderManager::CreateGraphicsState(const GraphicsStateCreateInfo& createInfo)
835 {
836     return CreateGraphicsState(createInfo, {});
837 }
838 
GetGraphicsStateHandle(const string_view name) const839 RenderHandleReference ShaderManager::GetGraphicsStateHandle(const string_view name) const
840 {
841     if (const auto iter = graphicsStates_.nameToIndex.find(name); iter != graphicsStates_.nameToIndex.cend()) {
842         PLUGIN_ASSERT(iter->second < graphicsStates_.rhr.size());
843         return graphicsStates_.rhr[iter->second];
844     } else {
845         PLUGIN_LOG_W("ShaderManager: named graphics state not found: %s", string(name).c_str());
846         return {};
847     }
848 }
849 
GetGraphicsStateHandle(const string_view name,const string_view variantName) const850 RenderHandleReference ShaderManager::GetGraphicsStateHandle(const string_view name, const string_view variantName) const
851 {
852     // NOTE: does not call the base GetGraphicsStateHandle due to better error logging
853     const string fullName = string(name + variantName);
854     if (const auto iter = graphicsStates_.nameToIndex.find(fullName); iter != graphicsStates_.nameToIndex.cend()) {
855         PLUGIN_ASSERT(iter->second < graphicsStates_.rhr.size());
856         return graphicsStates_.rhr[iter->second];
857     } else {
858         PLUGIN_LOG_W(
859             "ShaderManager: named graphics state not found (name: %s variant: %s)", name.data(), variantName.data());
860         return {};
861     }
862 }
863 
GetGraphicsStateHandle(const RenderHandle & handle,const uint32_t renderSlotId) const864 RenderHandleReference ShaderManager::GetGraphicsStateHandle(
865     const RenderHandle& handle, const uint32_t renderSlotId) const
866 {
867     if (RenderHandleUtil::GetHandleType(handle) == RenderHandleType::GRAPHICS_STATE) {
868         const uint32_t arrayIndex = RenderHandleUtil::GetIndexPart(handle);
869         if (arrayIndex < static_cast<uint32_t>(graphicsStates_.data.size())) {
870             // check for own validity
871             const auto& data = graphicsStates_.data[arrayIndex];
872             if (renderSlotId == data.renderSlotId) {
873                 return graphicsStates_.rhr[arrayIndex];
874             }
875             // check for base variant for hashing
876             if (data.baseVariantIndex < static_cast<uint32_t>(graphicsStates_.data.size())) {
877                 const RenderHandle baseHandle = graphicsStates_.rhr[data.baseVariantIndex].GetHandle();
878                 const uint64_t hash = HashHandleAndSlot(baseHandle, renderSlotId);
879                 if (const auto iter = graphicsStates_.variantHashToIndex.find(hash);
880                     iter != graphicsStates_.variantHashToIndex.cend()) {
881                     PLUGIN_ASSERT(iter->second < static_cast<uint32_t>(graphicsStates_.rhr.size()));
882                     return graphicsStates_.rhr[iter->second];
883                 }
884             }
885         }
886     }
887     return {};
888 }
889 
// Reference-counted overload: forwards the raw handle to the RenderHandle overload.
RenderHandleReference ShaderManager::GetGraphicsStateHandle(
    const RenderHandleReference& handle, const uint32_t renderSlotId) const
{
    const RenderHandle rawHandle = handle.GetHandle();
    return GetGraphicsStateHandle(rawHandle, renderSlotId);
}
895 
GetGraphicsStateHandleByHash(const uint64_t hash) const896 RenderHandleReference ShaderManager::GetGraphicsStateHandleByHash(const uint64_t hash) const
897 {
898     if (const auto iter = graphicsStates_.hashToIndex.find(hash); iter != graphicsStates_.hashToIndex.cend()) {
899         PLUGIN_ASSERT(iter->second < graphicsStates_.rhr.size());
900         return graphicsStates_.rhr[iter->second];
901     } else {
902         return {};
903     }
904 }
905 
GetGraphicsStateHandleByShaderHandle(const RenderHandle & handle) const906 RenderHandleReference ShaderManager::GetGraphicsStateHandleByShaderHandle(const RenderHandle& handle) const
907 {
908     if (RenderHandleUtil::GetHandleType(handle) == RenderHandleType::SHADER_STATE_OBJECT) {
909         const uint32_t arrayIndex = RenderHandleUtil::GetIndexPart(handle);
910         if (arrayIndex < static_cast<uint32_t>(shaderMappings_.clientData.size())) {
911             const uint32_t gsIndex = shaderMappings_.clientData[arrayIndex].graphicsStateIndex;
912             if (gsIndex < static_cast<uint32_t>(graphicsStates_.graphicsStates.size())) {
913                 return graphicsStates_.rhr[gsIndex];
914             }
915 #if (RENDER_VALIDATION_ENABLED == 1)
916             PLUGIN_ASSERT(gsIndex != INVALID_SM_INDEX); // not and optional index ATM
917             PLUGIN_ASSERT(gsIndex < graphicsStates_.rhr.size());
918 #endif
919         }
920     }
921     return {};
922 }
923 
// Reference-counted overload: forwards the raw handle to the RenderHandle overload.
RenderHandleReference ShaderManager::GetGraphicsStateHandleByShaderHandle(const RenderHandleReference& handle) const
{
    const RenderHandle rawHandle = handle.GetHandle();
    return GetGraphicsStateHandleByShaderHandle(rawHandle);
}
928 
GetGraphicsState(const RenderHandleReference & handle) const929 GraphicsState ShaderManager::GetGraphicsState(const RenderHandleReference& handle) const
930 {
931     return GetGraphicsStateRef(handle);
932 }
933 
GetGraphicsStateRef(const RenderHandle & handle) const934 const GraphicsState& ShaderManager::GetGraphicsStateRef(const RenderHandle& handle) const
935 {
936     const RenderHandleType type = RenderHandleUtil::GetHandleType(handle);
937     const uint32_t arrayIndex = RenderHandleUtil::GetIndexPart(handle);
938     if ((type == RenderHandleType::GRAPHICS_STATE) &&
939         (arrayIndex < static_cast<uint32_t>(graphicsStates_.graphicsStates.size()))) {
940         return graphicsStates_.graphicsStates[arrayIndex];
941     } else {
942 #if (RENDER_VALIDATION_ENABLED == 1)
943         if (RenderHandleUtil::IsValid(handle) && (type != RenderHandleType::GRAPHICS_STATE)) {
944             PLUGIN_LOG_W("RENDER_VALIDATION: invalid handle type given to GetGraphicsState()");
945         }
946 #endif
947         return defaultGraphicsState_;
948     }
949 }
950 
GetGraphicsStateRef(const RenderHandleReference & handle) const951 const GraphicsState& ShaderManager::GetGraphicsStateRef(const RenderHandleReference& handle) const
952 {
953     return GetGraphicsStateRef(handle.GetHandle());
954 }
955 
GetRenderSlotId(const string_view renderSlot) const956 uint32_t ShaderManager::GetRenderSlotId(const string_view renderSlot) const
957 {
958     if (const auto iter = renderSlotIds_.nameToId.find(renderSlot); iter != renderSlotIds_.nameToId.cend()) {
959         return iter->second;
960     } else {
961         return INVALID_SM_INDEX;
962     }
963 }
964 
GetRenderSlotId(const RenderHandle & handle) const965 uint32_t ShaderManager::GetRenderSlotId(const RenderHandle& handle) const
966 {
967     uint32_t id = ~0u;
968     const RenderHandleType handleType = RenderHandleUtil::GetHandleType(handle);
969     const uint32_t arrayIndex = RenderHandleUtil::GetIndexPart(handle);
970     if (handleType == RenderHandleType::COMPUTE_SHADER_STATE_OBJECT) {
971         if (arrayIndex < computeShaderMappings_.clientData.size()) {
972             id = computeShaderMappings_.clientData[arrayIndex].renderSlotId;
973         }
974     } else if (handleType == RenderHandleType::SHADER_STATE_OBJECT) {
975         if (arrayIndex < shaderMappings_.clientData.size()) {
976             id = shaderMappings_.clientData[arrayIndex].renderSlotId;
977         }
978     } else if (handleType == RenderHandleType::GRAPHICS_STATE) {
979         if (arrayIndex < graphicsStates_.data.size()) {
980             id = graphicsStates_.data[arrayIndex].renderSlotId;
981         }
982     }
983     return id;
984 }
985 
GetRenderSlotId(const RenderHandleReference & handle) const986 uint32_t ShaderManager::GetRenderSlotId(const RenderHandleReference& handle) const
987 {
988     return GetRenderSlotId(handle.GetHandle());
989 }
990 
GetRenderSlotData(const uint32_t renderSlotId) const991 IShaderManager::RenderSlotData ShaderManager::GetRenderSlotData(const uint32_t renderSlotId) const
992 {
993     if (renderSlotId < static_cast<uint32_t>(renderSlotIds_.data.size())) {
994         return renderSlotIds_.data[renderSlotId];
995     } else {
996         return {};
997     }
998 }
999 
GetVertexInputDeclarationHandleByShaderHandle(const RenderHandle & handle) const1000 RenderHandleReference ShaderManager::GetVertexInputDeclarationHandleByShaderHandle(const RenderHandle& handle) const
1001 {
1002     if (RenderHandleUtil::GetHandleType(handle) == RenderHandleType::SHADER_STATE_OBJECT) {
1003         const uint32_t arrayIndex = RenderHandleUtil::GetIndexPart(handle);
1004         auto& mappings = shaderMappings_;
1005         if (arrayIndex < mappings.clientData.size()) {
1006             const uint32_t vidIndex = mappings.clientData[arrayIndex].vertexInputDeclarationIndex;
1007             if (vidIndex < shaderVid_.rhr.size()) {
1008                 PLUGIN_ASSERT(vidIndex < shaderVid_.rhr.size());
1009                 return shaderVid_.rhr[vidIndex];
1010             }
1011         }
1012     }
1013     return {};
1014 }
1015 
// Reference-counted overload: forwards the raw handle to the RenderHandle overload.
RenderHandleReference ShaderManager::GetVertexInputDeclarationHandleByShaderHandle(
    const RenderHandleReference& handle) const
{
    const RenderHandle rawHandle = handle.GetHandle();
    return GetVertexInputDeclarationHandleByShaderHandle(rawHandle);
}
1021 
GetVertexInputDeclarationHandle(const string_view name) const1022 RenderHandleReference ShaderManager::GetVertexInputDeclarationHandle(const string_view name) const
1023 {
1024     if (const auto iter = shaderVid_.nameToIndex.find(name); iter != shaderVid_.nameToIndex.cend()) {
1025         PLUGIN_ASSERT(iter->second < shaderVid_.rhr.size());
1026         return shaderVid_.rhr[iter->second];
1027     } else {
1028         PLUGIN_LOG_W("ShaderManager: named vertex input declaration not found: %s", name.data());
1029         return {};
1030     }
1031 }
1032 
GetVertexInputDeclarationView(const RenderHandle & handle) const1033 VertexInputDeclarationView ShaderManager::GetVertexInputDeclarationView(const RenderHandle& handle) const
1034 {
1035     const RenderHandleType type = RenderHandleUtil::GetHandleType(handle);
1036     const uint32_t index = RenderHandleUtil::GetIndexPart(handle);
1037     if ((type == RenderHandleType::VERTEX_INPUT_DECLARATION) &&
1038         (index < static_cast<uint32_t>(shaderVid_.data.size()))) {
1039         const auto& ref = shaderVid_.data[index];
1040         return {
1041             array_view<const VertexInputDeclaration::VertexInputBindingDescription>(
1042                 ref.bindingDescriptions, ref.bindingDescriptionCount),
1043             array_view<const VertexInputDeclaration::VertexInputAttributeDescription>(
1044                 ref.attributeDescriptions, ref.attributeDescriptionCount),
1045         };
1046     } else {
1047 #if (RENDER_VALIDATION_ENABLED == 1)
1048         if (RenderHandleUtil::IsValid(handle) && (type != RenderHandleType::VERTEX_INPUT_DECLARATION)) {
1049             PLUGIN_LOG_W("RENDER_VALIDATION: invalid handle type given to GetVertexInputDeclarationView()");
1050         }
1051 #endif
1052         return {};
1053     }
1054 }
1055 
GetVertexInputDeclarationView(const RenderHandleReference & handle) const1056 VertexInputDeclarationView ShaderManager::GetVertexInputDeclarationView(const RenderHandleReference& handle) const
1057 {
1058     return GetVertexInputDeclarationView(handle.GetHandle());
1059 }
1060 
CreateVertexInputDeclaration(const VertexInputDeclarationCreateInfo & createInfo)1061 RenderHandleReference ShaderManager::CreateVertexInputDeclaration(const VertexInputDeclarationCreateInfo& createInfo)
1062 {
1063     uint32_t arrayIndex = INVALID_SM_INDEX;
1064     if (auto nameIter = shaderVid_.nameToIndex.find(createInfo.path); nameIter != shaderVid_.nameToIndex.end()) {
1065         PLUGIN_ASSERT(nameIter->second < shaderVid_.rhr.size());
1066         arrayIndex = static_cast<uint32_t>(nameIter->second);
1067     }
1068     if (arrayIndex < static_cast<uint32_t>(shaderVid_.data.size())) {
1069         // inside core validation due to being very low info for common users
1070 #if (RENDER_VALIDATION_ENABLED == 1)
1071         PLUGIN_LOG_I("ShaderManager: re-creating vertex input declaration (name %s)", createInfo.path.data());
1072 #endif
1073     } else { // new
1074         arrayIndex = static_cast<uint32_t>(shaderVid_.data.size());
1075         const RenderHandle handle =
1076             RenderHandleUtil::CreateHandle(RenderHandleType::VERTEX_INPUT_DECLARATION, arrayIndex);
1077         shaderVid_.rhr.emplace_back(
1078             RenderHandleReference(handle, IRenderReferenceCounter::Ptr(new ShaderReferenceCounter())));
1079         shaderVid_.data.emplace_back(VertexInputDeclarationData {});
1080         // NOTE: only updated for new
1081         if (!createInfo.path.empty()) {
1082             shaderVid_.nameToIndex[createInfo.path] = arrayIndex;
1083         }
1084     }
1085 
1086     if (arrayIndex < static_cast<uint32_t>(shaderVid_.data.size())) {
1087         const VertexInputDeclarationView& vertexInputDeclarationView = createInfo.vertexInputDeclarationView;
1088         VertexInputDeclarationData& ref = shaderVid_.data[arrayIndex];
1089         ref.bindingDescriptionCount = (uint32_t)vertexInputDeclarationView.bindingDescriptions.size();
1090         ref.attributeDescriptionCount = (uint32_t)vertexInputDeclarationView.attributeDescriptions.size();
1091 
1092         PLUGIN_ASSERT(ref.bindingDescriptionCount <= PipelineStateConstants::MAX_VERTEX_BUFFER_COUNT);
1093         PLUGIN_ASSERT(ref.attributeDescriptionCount <= PipelineStateConstants::MAX_VERTEX_BUFFER_COUNT);
1094 
1095         for (uint32_t idx = 0; idx < ref.bindingDescriptionCount; ++idx) {
1096             ref.bindingDescriptions[idx] = vertexInputDeclarationView.bindingDescriptions[idx];
1097         }
1098         for (uint32_t idx = 0; idx < ref.attributeDescriptionCount; ++idx) {
1099             ref.attributeDescriptions[idx] = vertexInputDeclarationView.attributeDescriptions[idx];
1100         }
1101         return shaderVid_.rhr[arrayIndex];
1102     } else {
1103         return {};
1104     }
1105 }
1106 
GetPipelineLayoutHandleByShaderHandle(const RenderHandle & handle) const1107 RenderHandleReference ShaderManager::GetPipelineLayoutHandleByShaderHandle(const RenderHandle& handle) const
1108 {
1109     const RenderHandleType type = RenderHandleUtil::GetHandleType(handle);
1110     const uint32_t arrayIndex = RenderHandleUtil::GetIndexPart(handle);
1111     if (type == RenderHandleType::SHADER_STATE_OBJECT) {
1112         auto& mappings = shaderMappings_;
1113         if (arrayIndex < mappings.clientData.size()) {
1114             const uint32_t plIndex = mappings.clientData[arrayIndex].pipelineLayoutIndex;
1115             if (plIndex < static_cast<uint32_t>(pl_.rhr.size())) {
1116                 PLUGIN_ASSERT(plIndex < static_cast<uint32_t>(pl_.rhr.size()));
1117                 return pl_.rhr[plIndex];
1118             }
1119         }
1120     } else if (type == RenderHandleType::COMPUTE_SHADER_STATE_OBJECT) {
1121         auto& mappings = computeShaderMappings_;
1122         if (arrayIndex < mappings.clientData.size()) {
1123             const uint32_t plIndex = mappings.clientData[arrayIndex].pipelineLayoutIndex;
1124             if (plIndex < static_cast<uint32_t>(pl_.rhr.size())) {
1125                 PLUGIN_ASSERT(plIndex < static_cast<uint32_t>(pl_.rhr.size()));
1126                 return pl_.rhr[plIndex];
1127             }
1128         }
1129     }
1130     return {};
1131 }
1132 
// Reference-counted overload: forwards the raw handle to the RenderHandle overload.
RenderHandleReference ShaderManager::GetPipelineLayoutHandleByShaderHandle(const RenderHandleReference& handle) const
{
    const RenderHandle rawHandle = handle.GetHandle();
    return GetPipelineLayoutHandleByShaderHandle(rawHandle);
}
1137 
GetPipelineLayoutHandle(const string_view name) const1138 RenderHandleReference ShaderManager::GetPipelineLayoutHandle(const string_view name) const
1139 {
1140     if (const auto iter = pl_.nameToIndex.find(name); iter != pl_.nameToIndex.cend()) {
1141         const uint32_t index = iter->second;
1142         PLUGIN_ASSERT(index < static_cast<uint32_t>(pl_.rhr.size()));
1143         return pl_.rhr[index];
1144     } else {
1145         PLUGIN_LOG_W("ShaderManager: named pipeline layout not found: %s", name.data());
1146         return {};
1147     }
1148 }
1149 
GetPipelineLayout(const RenderHandle & handle) const1150 PipelineLayout ShaderManager::GetPipelineLayout(const RenderHandle& handle) const
1151 {
1152     return GetPipelineLayoutRef(handle);
1153 }
1154 
GetPipelineLayout(const RenderHandleReference & handle) const1155 PipelineLayout ShaderManager::GetPipelineLayout(const RenderHandleReference& handle) const
1156 {
1157     return GetPipelineLayoutRef(handle.GetHandle());
1158 }
1159 
GetPipelineLayoutRef(const RenderHandle & handle) const1160 const PipelineLayout& ShaderManager::GetPipelineLayoutRef(const RenderHandle& handle) const
1161 {
1162     const RenderHandleType type = RenderHandleUtil::GetHandleType(handle);
1163     const uint32_t index = RenderHandleUtil::GetIndexPart(handle);
1164     if ((type == RenderHandleType::PIPELINE_LAYOUT) && (index < static_cast<uint32_t>(pl_.data.size()))) {
1165         return pl_.data[index];
1166     } else {
1167 #if (RENDER_VALIDATION_ENABLED == 1)
1168         if (RenderHandleUtil::IsValid(handle) && (type != RenderHandleType::PIPELINE_LAYOUT)) {
1169             PLUGIN_LOG_W("RENDER_VALIDATION: invalid handle type given to GetPipelineLayout()");
1170         }
1171 #endif
1172         return defaultPipelineLayout_;
1173     }
1174 }
1175 
GetReflectionPipelineLayoutHandle(const RenderHandle & handle) const1176 RenderHandleReference ShaderManager::GetReflectionPipelineLayoutHandle(const RenderHandle& handle) const
1177 {
1178     const RenderHandleType type = RenderHandleUtil::GetHandleType(handle);
1179     const uint32_t arrayIndex = RenderHandleUtil::GetIndexPart(handle);
1180     uint32_t plIndex = INVALID_SM_INDEX;
1181     if (type == RenderHandleType::SHADER_STATE_OBJECT) {
1182         if (arrayIndex < shaderMappings_.clientData.size()) {
1183             plIndex = shaderMappings_.clientData[arrayIndex].reflectionPipelineLayoutIndex;
1184         }
1185     } else if (type == RenderHandleType::COMPUTE_SHADER_STATE_OBJECT) {
1186         if (arrayIndex < computeShaderMappings_.clientData.size()) {
1187             plIndex = computeShaderMappings_.clientData[arrayIndex].reflectionPipelineLayoutIndex;
1188         }
1189     }
1190 
1191     if (plIndex < pl_.data.size()) {
1192         return pl_.rhr[plIndex];
1193     } else {
1194 #if (RENDER_VALIDATION_ENABLED == 1)
1195         PLUGIN_LOG_W("RENDER_VALIDATION: ShaderManager, invalid shader handle for GetReflectionPipelineLayoutHandle");
1196 #endif
1197         return {};
1198     }
1199 }
1200 
// Reference-counted overload: forwards the raw handle to the RenderHandle overload.
RenderHandleReference ShaderManager::GetReflectionPipelineLayoutHandle(const RenderHandleReference& handle) const
{
    const RenderHandle rawHandle = handle.GetHandle();
    return GetReflectionPipelineLayoutHandle(rawHandle);
}
1205 
GetReflectionPipelineLayout(const RenderHandleReference & handle) const1206 PipelineLayout ShaderManager::GetReflectionPipelineLayout(const RenderHandleReference& handle) const
1207 {
1208     return GetReflectionPipelineLayoutRef(handle.GetHandle());
1209 }
1210 
GetReflectionPipelineLayoutRef(const RenderHandle & handle) const1211 const PipelineLayout& ShaderManager::GetReflectionPipelineLayoutRef(const RenderHandle& handle) const
1212 {
1213     const RenderHandleType type = RenderHandleUtil::GetHandleType(handle);
1214     const uint32_t arrayIndex = RenderHandleUtil::GetIndexPart(handle);
1215     uint32_t plIndex = INVALID_SM_INDEX;
1216     if (type == RenderHandleType::SHADER_STATE_OBJECT) {
1217         if (arrayIndex < shaderMappings_.clientData.size()) {
1218             plIndex = shaderMappings_.clientData[arrayIndex].reflectionPipelineLayoutIndex;
1219         }
1220     } else if (type == RenderHandleType::COMPUTE_SHADER_STATE_OBJECT) {
1221         if (arrayIndex < computeShaderMappings_.clientData.size()) {
1222             plIndex = computeShaderMappings_.clientData[arrayIndex].reflectionPipelineLayoutIndex;
1223         }
1224     }
1225 
1226     if (plIndex < pl_.data.size()) {
1227         return pl_.data[plIndex];
1228     } else {
1229 #if (RENDER_VALIDATION_ENABLED == 1)
1230         PLUGIN_LOG_W("RENDER_VALIDATION: ShaderManager, invalid shader handle for GetReflectionPipelineLayout");
1231 #endif
1232         return defaultPipelineLayout_;
1233     }
1234 }
1235 
GetReflectionSpecialization(const RenderHandle & handle) const1236 ShaderSpecilizationConstantView ShaderManager::GetReflectionSpecialization(const RenderHandle& handle) const
1237 {
1238     const RenderHandleType type = RenderHandleUtil::GetHandleType(handle);
1239     const uint32_t arrayIndex = RenderHandleUtil::GetIndexPart(handle);
1240     if (type == RenderHandleType::SHADER_STATE_OBJECT) {
1241         // NOTE: at the moment there might not be availability yet, will be FIXED
1242         if (arrayIndex < shaders_.size()) {
1243             if (shaders_[arrayIndex].gsp) {
1244                 return shaders_[arrayIndex].gsp->GetReflection().shaderSpecializationConstantView;
1245             }
1246         }
1247     } else if (type == RenderHandleType::COMPUTE_SHADER_STATE_OBJECT) {
1248         // NOTE: at the moment there might not be availability yet, will be FIXED
1249         if (arrayIndex < computeShaders_.size()) {
1250             if (computeShaders_[arrayIndex].gsp) {
1251                 return computeShaders_[arrayIndex].gsp->GetReflection().shaderSpecializationConstantView;
1252             }
1253         }
1254     }
1255 #if (RENDER_VALIDATION_ENABLED == 1)
1256     PLUGIN_LOG_W("RENDER_VALIDATION: ShaderManager, invalid shader handle for GetReflectionSpecialization");
1257 #endif
1258     return defaultSSCV_;
1259 }
1260 
GetReflectionSpecialization(const RenderHandleReference & handle) const1261 ShaderSpecilizationConstantView ShaderManager::GetReflectionSpecialization(const RenderHandleReference& handle) const
1262 {
1263     return GetReflectionSpecialization(handle.GetHandle());
1264 }
1265 
GetReflectionVertexInputDeclaration(const RenderHandle & handle) const1266 VertexInputDeclarationView ShaderManager::GetReflectionVertexInputDeclaration(const RenderHandle& handle) const
1267 {
1268     const RenderHandleType type = RenderHandleUtil::GetHandleType(handle);
1269     const uint32_t arrayIndex = RenderHandleUtil::GetIndexPart(handle);
1270     if (type == RenderHandleType::SHADER_STATE_OBJECT) {
1271         // NOTE: at the moment there might not be availability yet, will be FIXED
1272         if (arrayIndex < shaders_.size()) {
1273             if (shaders_[arrayIndex].gsp) {
1274                 return shaders_[arrayIndex].gsp->GetReflection().vertexInputDeclarationView;
1275             }
1276         }
1277     }
1278 #if (RENDER_VALIDATION_ENABLED == 1)
1279     PLUGIN_LOG_W("RENDER_VALIDATION: ShaderManager, invalid shader handle for GetReflectionVertexInputDeclaration");
1280 #endif
1281     return defaultVIDV_;
1282 }
1283 
GetReflectionVertexInputDeclaration(const RenderHandleReference & handle) const1284 VertexInputDeclarationView ShaderManager::GetReflectionVertexInputDeclaration(const RenderHandleReference& handle) const
1285 {
1286     return GetReflectionVertexInputDeclaration(handle.GetHandle());
1287 }
1288 
GetReflectionThreadGroupSize(const RenderHandle & handle) const1289 ShaderThreadGroup ShaderManager::GetReflectionThreadGroupSize(const RenderHandle& handle) const
1290 {
1291     const RenderHandleType type = RenderHandleUtil::GetHandleType(handle);
1292     const uint32_t arrayIndex = RenderHandleUtil::GetIndexPart(handle);
1293     if (type == RenderHandleType::COMPUTE_SHADER_STATE_OBJECT) {
1294         // NOTE: at the moment there might not be availability yet, will be FIXED
1295         if (arrayIndex < computeShaders_.size()) {
1296             if (computeShaders_[arrayIndex].gsp) {
1297                 const auto& refl = computeShaders_[arrayIndex].gsp->GetReflection();
1298                 return { refl.threadGroupSizeX, refl.threadGroupSizeY, refl.threadGroupSizeZ };
1299             }
1300         }
1301     }
1302 #if (RENDER_VALIDATION_ENABLED == 1)
1303     PLUGIN_LOG_W("RENDER_VALIDATION: ShaderManager, invalid shader handle for GetReflectionThreadGroupSize");
1304 #endif
1305     return defaultSTG_;
1306 }
1307 
GetReflectionThreadGroupSize(const RenderHandleReference & handle) const1308 ShaderThreadGroup ShaderManager::GetReflectionThreadGroupSize(const RenderHandleReference& handle) const
1309 {
1310     return GetReflectionThreadGroupSize(handle.GetHandle());
1311 }
1312 
CreatePipelineLayout(const PipelineLayoutCreateInfo & createInfo)1313 RenderHandleReference ShaderManager::CreatePipelineLayout(const PipelineLayoutCreateInfo& createInfo)
1314 {
1315     uint32_t arrayIndex = INVALID_SM_INDEX;
1316     if (auto nameIter = pl_.nameToIndex.find(createInfo.path); nameIter != pl_.nameToIndex.end()) {
1317         PLUGIN_ASSERT(nameIter->second < pl_.rhr.size());
1318         arrayIndex = static_cast<uint32_t>(nameIter->second);
1319     }
1320 
1321     if (arrayIndex < static_cast<uint32_t>(pl_.data.size())) { // replace
1322         // inside core validation due to being very low info for common users
1323 #if (RENDER_VALIDATION_ENABLED == 1)
1324         PLUGIN_LOG_I("ShaderManager: re-creating pipeline layout (name %s)", createInfo.path.data());
1325 #endif
1326     } else { // new
1327         arrayIndex = static_cast<uint32_t>(pl_.data.size());
1328         pl_.data.emplace_back(PipelineLayout {});
1329         // NOTE: only updated for new (should check with re-creation)
1330         if (!createInfo.path.empty()) {
1331             pl_.nameToIndex[createInfo.path] = arrayIndex;
1332         }
1333         pl_.rhr.emplace_back(RenderHandleReference {});
1334     }
1335 
1336     if (arrayIndex < static_cast<uint32_t>(pl_.data.size())) {
1337         const PipelineLayout& pipelineLayout = createInfo.pipelineLayout;
1338         PipelineLayout& ref = pl_.data[arrayIndex];
1339 #if (RENDER_VALIDATION_ENABLED == 1)
1340         if (pipelineLayout.descriptorSetCount > PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT &&
1341             pipelineLayout.pushConstant.byteSize > PipelineLayoutConstants::MAX_PUSH_CONSTANT_BYTE_SIZE) {
1342             PLUGIN_LOG_W(
1343                 "Invalid pipeline layout sizes clamped (name:%s). Set count %u <= %u, push constant size %u <= %u",
1344                 createInfo.path.data(), ref.descriptorSetCount, PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT,
1345                 pipelineLayout.pushConstant.byteSize, PipelineLayoutConstants::MAX_PUSH_CONSTANT_BYTE_SIZE);
1346         }
1347 #endif
1348         ref.pushConstant = pipelineLayout.pushConstant;
1349         ref.descriptorSetCount =
1350             Math::min(PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT, pipelineLayout.descriptorSetCount);
1351         ref.pushConstant.byteSize =
1352             Math::min(PipelineLayoutConstants::MAX_PUSH_CONSTANT_BYTE_SIZE, pipelineLayout.pushConstant.byteSize);
1353         uint32_t descriptorSetBitmask = 0;
1354         // can be user generated pipeline layout (i.e. set index might be different than index)
1355         for (uint32_t idx = 0; idx < PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT; ++idx) {
1356             const uint32_t setIdx = pipelineLayout.descriptorSetLayouts[idx].set;
1357             if (setIdx < PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT) {
1358                 ref.descriptorSetLayouts[setIdx] = pipelineLayout.descriptorSetLayouts[setIdx];
1359                 descriptorSetBitmask |= (1 << setIdx);
1360             }
1361         }
1362 
1363         const RenderHandle handle =
1364             RenderHandleUtil::CreateHandle(RenderHandleType::PIPELINE_LAYOUT, arrayIndex, 0, descriptorSetBitmask);
1365         pl_.rhr[arrayIndex] = RenderHandleReference(handle, IRenderReferenceCounter::Ptr(new ShaderReferenceCounter()));
1366         return pl_.rhr[arrayIndex];
1367     } else {
1368         return {};
1369     }
1370 }
1371 
GetGpuComputeProgram(const RenderHandle & handle) const1372 const GpuComputeProgram* ShaderManager::GetGpuComputeProgram(const RenderHandle& handle) const
1373 {
1374     if (!IsComputeShaderFunc(handle)) {
1375         PLUGIN_LOG_E("ShaderManager: invalid compute shader handle");
1376         return nullptr;
1377     }
1378     const uint32_t index = RenderHandleUtil::GetIndexPart(handle);
1379     if (index < static_cast<uint32_t>(computeShaders_.size())) {
1380         return computeShaders_[index].gsp.get();
1381     } else {
1382         PLUGIN_LOG_E("ShaderManager: invalid compute shader handle");
1383         return nullptr;
1384     }
1385 }
1386 
GetGpuShaderProgram(const RenderHandle & handle) const1387 const GpuShaderProgram* ShaderManager::GetGpuShaderProgram(const RenderHandle& handle) const
1388 {
1389     if (!IsShaderFunc(handle)) {
1390         PLUGIN_LOG_E("ShaderManager: invalid shader handle");
1391         return nullptr;
1392     }
1393     const uint32_t index = RenderHandleUtil::GetIndexPart(handle);
1394     if (index < static_cast<uint32_t>(shaders_.size())) {
1395         return shaders_[index].gsp.get();
1396     } else {
1397         PLUGIN_LOG_E("ShaderManager: invalid shader handle");
1398         return nullptr;
1399     }
1400 }
1401 
CreateShaderModule(const string_view name,const ShaderModuleCreateInfo & createInfo)1402 uint32_t ShaderManager::CreateShaderModule(const string_view name, const ShaderModuleCreateInfo& createInfo)
1403 {
1404     auto& nameToIdx = shaderModules_.nameToIndex;
1405     auto& modules = shaderModules_.shaderModules;
1406     if (auto iter = nameToIdx.find(name); iter != nameToIdx.end()) {
1407         PLUGIN_ASSERT(iter->second < modules.size());
1408         // inside core validation due to being very low info for common users
1409 #if (RENDER_VALIDATION_ENABLED == 1)
1410         PLUGIN_LOG_I("ShaderManager: re-creating shader module of name %s", name.data());
1411 #endif
1412         // check that we don't push the same indices multiple times
1413         bool found = false;
1414         for (const auto& ref : pendingAllocations_.recreatedShaderModuleIndices) {
1415             if (ref == iter->second) {
1416                 found = true;
1417                 break;
1418             }
1419         }
1420         if (!found) {
1421             pendingAllocations_.recreatedShaderModuleIndices.emplace_back(iter->second);
1422         }
1423         deferredDestructions_.shaderModules.push_back({ device_.GetFrameCount(), move(modules[iter->second]) });
1424         modules[iter->second] = device_.CreateShaderModule(createInfo);
1425         return iter->second;
1426     } else {
1427         const uint32_t idx = static_cast<uint32_t>(modules.size());
1428         if (!name.empty()) {
1429             nameToIdx[name] = idx;
1430         }
1431         modules.emplace_back(device_.CreateShaderModule(createInfo));
1432         return idx;
1433     }
1434 }
1435 
GetShaderModule(const uint32_t index) const1436 ShaderModule* ShaderManager::GetShaderModule(const uint32_t index) const
1437 {
1438     const auto& modules = shaderModules_.shaderModules;
1439     if (index < modules.size()) {
1440         return modules[index].get();
1441     } else {
1442         return nullptr;
1443     }
1444 }
1445 
GetShaderModuleIndex(const string_view name) const1446 uint32_t ShaderManager::GetShaderModuleIndex(const string_view name) const
1447 {
1448     const auto& nameToIdx = shaderModules_.nameToIndex;
1449     if (const auto iter = nameToIdx.find(name); iter != nameToIdx.cend()) {
1450         PLUGIN_ASSERT(iter->second < shaderModules_.shaderModules.size());
1451         return iter->second;
1452     } else {
1453         return INVALID_SM_INDEX;
1454     }
1455 }
1456 
IsComputeShader(const RenderHandleReference & handle) const1457 bool ShaderManager::IsComputeShader(const RenderHandleReference& handle) const
1458 {
1459     return IsComputeShaderFunc(handle.GetHandle());
1460 }
1461 
IsShader(const RenderHandleReference & handle) const1462 bool ShaderManager::IsShader(const RenderHandleReference& handle) const
1463 {
1464     return IsShaderFunc(handle.GetHandle());
1465 }
1466 
LoadShaderFiles(const ShaderFilePathDesc & desc)1467 void ShaderManager::LoadShaderFiles(const ShaderFilePathDesc& desc)
1468 {
1469     if (shaderLoader_) {
1470         shaderLoader_->Load(desc);
1471     }
1472 }
1473 
LoadShaderFile(const string_view uri)1474 void ShaderManager::LoadShaderFile(const string_view uri)
1475 {
1476     if (shaderLoader_ && (!uri.empty())) {
1477         shaderLoader_->LoadFile(uri, false);
1478     }
1479 }
1480 
UnloadShaderFiles(const ShaderFilePathDesc & desc)1481 void ShaderManager::UnloadShaderFiles(const ShaderFilePathDesc& desc) {}
1482 
ReloadShaderFile(const string_view uri)1483 void ShaderManager::ReloadShaderFile(const string_view uri)
1484 {
1485     if (shaderLoader_ && (!uri.empty())) {
1486         shaderLoader_->LoadFile(uri, true);
1487         hasReloadedShaders_ = true;
1488     }
1489 }
1490 
ReloadSpvFiles(const array_view<string> & spvFiles)1491 void ShaderManager::ReloadSpvFiles(const array_view<string>& spvFiles)
1492 {
1493     if (shaderLoader_ && (!spvFiles.empty())) {
1494         shaderLoader_->Reload(spvFiles);
1495         hasReloadedShaders_ = true;
1496     }
1497 }
1498 
HasReloadedShaders() const1499 bool ShaderManager::HasReloadedShaders() const
1500 {
1501     return hasReloadedShaders_;
1502 }
1503 
GetMaterialMetadata(const RenderHandleReference & handle) const1504 const json::value* ShaderManager::GetMaterialMetadata(const RenderHandleReference& handle) const
1505 {
1506     if (const auto iter = shaderToMetadata_.find(handle.GetHandle()); iter != shaderToMetadata_.end()) {
1507         return &iter->second.json;
1508     }
1509     return nullptr;
1510 }
1511 
DestroyShader(const RenderHandle handle)1512 void ShaderManager::DestroyShader(const RenderHandle handle)
1513 {
1514     auto eraseIndexData = [](auto& mapStore, const RenderHandle handle) {
1515         if (auto const pos = std::find_if(
1516                 mapStore.begin(), mapStore.end(), [handle](auto const& element) { return element.second == handle; });
1517             pos != mapStore.end()) {
1518             mapStore.erase(pos);
1519         }
1520     };
1521 
1522     const uint32_t index = RenderHandleUtil::GetIndexPart(handle);
1523     if (IsComputeShaderFunc(handle)) {
1524         auto& mappings = computeShaderMappings_;
1525         if (index < static_cast<uint32_t>(mappings.clientData.size())) {
1526             mappings.clientData[index] = {};
1527             eraseIndexData(nameToClientHandle_, handle);
1528             {
1529                 const auto lock = std::lock_guard(pendingMutex_);
1530                 pendingAllocations_.destroyHandles.emplace_back(handle);
1531             }
1532         }
1533     } else if (IsShaderFunc(handle)) {
1534         auto& mappings = shaderMappings_;
1535         if (index < static_cast<uint32_t>(mappings.clientData.size())) {
1536             mappings.clientData[index] = {};
1537             eraseIndexData(nameToClientHandle_, handle);
1538             {
1539                 const auto lock = std::lock_guard(pendingMutex_);
1540                 pendingAllocations_.destroyHandles.emplace_back(handle);
1541             }
1542         }
1543     }
1544 }
1545 
Destroy(const RenderHandleReference & handle)1546 void ShaderManager::Destroy(const RenderHandleReference& handle)
1547 {
1548     const RenderHandle rawHandle = handle.GetHandle();
1549     const RenderHandleType handleType = RenderHandleUtil::GetHandleType(rawHandle);
1550     if ((handleType == RenderHandleType::COMPUTE_SHADER_STATE_OBJECT) ||
1551         (handleType == RenderHandleType::SHADER_STATE_OBJECT)) {
1552         DestroyShader(rawHandle);
1553     } else if (handleType == RenderHandleType::GRAPHICS_STATE) {
1554         DestroyGraphicsState(rawHandle);
1555     } else if (handleType == RenderHandleType::PIPELINE_LAYOUT) {
1556         DestroyPipelineLayout(rawHandle);
1557     } else if (handleType == RenderHandleType::VERTEX_INPUT_DECLARATION) {
1558         DestroyVertexInputDeclaration(rawHandle);
1559     }
1560 }
1561 
DestroyGraphicsState(const RenderHandle handle)1562 void ShaderManager::DestroyGraphicsState(const RenderHandle handle)
1563 {
1564     const uint32_t index = RenderHandleUtil::GetIndexPart(handle);
1565     if (index < static_cast<uint32_t>(graphicsStates_.rhr.size())) {
1566         graphicsStates_.rhr[index] = {};
1567         graphicsStates_.data[index] = {};
1568         graphicsStates_.graphicsStates[index] = {};
1569 
1570         auto eraseIndexData = [](auto& mapStore, const uint32_t index) {
1571             if (auto const pos = std::find_if(
1572                     mapStore.begin(), mapStore.end(), [index](auto const& element) { return element.second == index; });
1573                 pos != mapStore.end()) {
1574                 mapStore.erase(pos);
1575             }
1576         };
1577         eraseIndexData(graphicsStates_.nameToIndex, index);
1578         eraseIndexData(graphicsStates_.hashToIndex, index);
1579         // NOTE: shaderToStates needs to be added
1580     }
1581 }
1582 
DestroyPipelineLayout(const RenderHandle handle)1583 void ShaderManager::DestroyPipelineLayout(const RenderHandle handle)
1584 {
1585     const uint32_t index = RenderHandleUtil::GetIndexPart(handle);
1586     if (index < static_cast<uint32_t>(pl_.rhr.size())) {
1587         pl_.rhr[index] = {};
1588         pl_.data[index] = {};
1589 
1590         auto eraseIndexData = [](auto& mapStore, const uint32_t index) {
1591             if (auto const pos = std::find_if(
1592                     mapStore.begin(), mapStore.end(), [index](auto const& element) { return element.second == index; });
1593                 pos != mapStore.end()) {
1594                 mapStore.erase(pos);
1595             }
1596         };
1597         eraseIndexData(pl_.nameToIndex, index);
1598         eraseIndexData(pl_.computeShaderToIndex, index);
1599         eraseIndexData(pl_.shaderToIndex, index);
1600     }
1601 }
1602 
DestroyVertexInputDeclaration(const RenderHandle handle)1603 void ShaderManager::DestroyVertexInputDeclaration(const RenderHandle handle)
1604 {
1605     const uint32_t index = RenderHandleUtil::GetIndexPart(handle);
1606     if (index < static_cast<uint32_t>(shaderVid_.rhr.size())) {
1607         shaderVid_.rhr[index] = {};
1608         shaderVid_.data[index] = {};
1609 
1610         auto eraseIndexData = [](auto& mapStore, const uint32_t index) {
1611             if (auto const pos = std::find_if(
1612                     mapStore.begin(), mapStore.end(), [index](auto const& element) { return element.second == index; });
1613                 pos != mapStore.end()) {
1614                 mapStore.erase(pos);
1615             }
1616         };
1617         eraseIndexData(shaderVid_.nameToIndex, index);
1618         eraseIndexData(shaderVid_.shaderToIndex, index);
1619     }
1620 }
1621 
GetShaders(const RenderHandleReference & handle,const ShaderStageFlags shaderStageFlags) const1622 vector<RenderHandleReference> ShaderManager::GetShaders(
1623     const RenderHandleReference& handle, const ShaderStageFlags shaderStageFlags) const
1624 {
1625     vector<RenderHandleReference> shaders;
1626     if ((shaderStageFlags &
1627             (CORE_SHADER_STAGE_VERTEX_BIT | CORE_SHADER_STAGE_FRAGMENT_BIT | CORE_SHADER_STAGE_COMPUTE_BIT)) == 0) {
1628         return shaders;
1629     }
1630     const RenderHandleType handleType = handle.GetHandleType();
1631     const uint32_t handleIndex = RenderHandleUtil::GetIndexPart(handle.GetHandle());
1632     if (handleType == RenderHandleType::GRAPHICS_STATE) {
1633 #if (RENDER_VALIDATION_ENABLED == 1)
1634         PLUGIN_LOG_W("RENDER_VALIDATION: GetShaders with graphics state handle not supported");
1635 #endif
1636     } else if ((handleType == RenderHandleType::PIPELINE_LAYOUT) ||
1637                (handleType == RenderHandleType::VERTEX_INPUT_DECLARATION)) {
1638         if (shaderStageFlags & ShaderStageFlagBits::CORE_SHADER_STAGE_COMPUTE_BIT) {
1639             for (const auto& ref : computeShaderMappings_.clientData) {
1640                 if (ref.pipelineLayoutIndex == handleIndex) {
1641                     shaders.emplace_back(ref.rhr);
1642                 }
1643             }
1644         }
1645         if (shaderStageFlags & ShaderStageFlagBits::CORE_SHADER_STAGE_ALL_GRAPHICS) {
1646             for (const auto& ref : shaderMappings_.clientData) {
1647                 if (ref.vertexInputDeclarationIndex == handleIndex) {
1648                     shaders.emplace_back(ref.rhr);
1649                 }
1650             }
1651         }
1652     }
1653     return shaders;
1654 }
1655 
GetShaders(const RenderHandle & handle,const ShaderStageFlags shaderStageFlags) const1656 vector<RenderHandle> ShaderManager::GetShaders(
1657     const RenderHandle& handle, const ShaderStageFlags shaderStageFlags) const
1658 {
1659     vector<RenderHandle> shaders;
1660     if ((shaderStageFlags &
1661             (CORE_SHADER_STAGE_VERTEX_BIT | CORE_SHADER_STAGE_FRAGMENT_BIT | CORE_SHADER_STAGE_COMPUTE_BIT)) == 0) {
1662         return shaders;
1663     }
1664     const RenderHandleType handleType = RenderHandleUtil::GetHandleType(handle);
1665     const uint32_t handleIndex = RenderHandleUtil::GetIndexPart(handle);
1666     if (handleType == RenderHandleType::GRAPHICS_STATE) {
1667 #if (RENDER_VALIDATION_ENABLED == 1)
1668         PLUGIN_LOG_W("RENDER_VALIDATION: GetShaders with graphics state handle not supported");
1669 #endif
1670     } else if ((handleType == RenderHandleType::PIPELINE_LAYOUT) ||
1671                (handleType == RenderHandleType::VERTEX_INPUT_DECLARATION)) {
1672         if (shaderStageFlags & ShaderStageFlagBits::CORE_SHADER_STAGE_COMPUTE_BIT) {
1673             for (const auto& ref : computeShaderMappings_.clientData) {
1674                 if (ref.pipelineLayoutIndex == handleIndex) {
1675                     shaders.emplace_back(ref.rhr.GetHandle());
1676                 }
1677             }
1678         }
1679         if (shaderStageFlags & ShaderStageFlagBits::CORE_SHADER_STAGE_ALL_GRAPHICS) {
1680             for (const auto& ref : shaderMappings_.clientData) {
1681                 if (ref.vertexInputDeclarationIndex == handleIndex) {
1682                     shaders.emplace_back(ref.rhr.GetHandle());
1683                 }
1684             }
1685         }
1686     }
1687     return shaders;
1688 }
1689 
GetShaderIdDesc(const RenderHandle handle) const1690 IShaderManager::IdDesc ShaderManager::GetShaderIdDesc(const RenderHandle handle) const
1691 {
1692     const RenderHandleType handleType = RenderHandleUtil::GetHandleType(handle);
1693     const uint32_t index = RenderHandleUtil::GetIndexPart(handle);
1694     IdDesc desc;
1695     if ((handleType == RenderHandleType::COMPUTE_SHADER_STATE_OBJECT) &&
1696         (index < static_cast<uint32_t>(computeShaderMappings_.clientData.size()))) {
1697         desc.renderSlotId = computeShaderMappings_.clientData[index].renderSlotId;
1698         for (const auto& ref : nameToClientHandle_) {
1699             if (ref.second == handle) {
1700                 desc.path = ref.first;
1701             }
1702         }
1703     } else if ((handleType == RenderHandleType::SHADER_STATE_OBJECT) &&
1704                (index < static_cast<uint32_t>(shaderMappings_.clientData.size()))) {
1705         desc.renderSlotId = shaderMappings_.clientData[index].renderSlotId;
1706         for (const auto& ref : nameToClientHandle_) {
1707             if (ref.second == handle) {
1708                 desc.path = ref.first;
1709             }
1710         }
1711     }
1712     return desc;
1713 }
1714 
GetIdDesc(const RenderHandleReference & handle) const1715 IShaderManager::IdDesc ShaderManager::GetIdDesc(const RenderHandleReference& handle) const
1716 {
1717     auto GetIdDesc = [](const auto& nameToIndex, const auto handleIndex) {
1718         IdDesc desc;
1719         for (const auto& ref : nameToIndex) {
1720             if (ref.second == handleIndex) {
1721                 desc.path = ref.first;
1722             }
1723         }
1724         return desc;
1725     };
1726     const RenderHandle rawHandle = handle.GetHandle();
1727     const RenderHandleType handleType = RenderHandleUtil::GetHandleType(rawHandle);
1728     const uint32_t handleIndex = RenderHandleUtil::GetIndexPart(rawHandle);
1729     IdDesc desc;
1730     if ((handleType == RenderHandleType::COMPUTE_SHADER_STATE_OBJECT) ||
1731         (handleType == RenderHandleType::SHADER_STATE_OBJECT)) {
1732         desc = GetShaderIdDesc(rawHandle);
1733     } else if ((handleType == RenderHandleType::GRAPHICS_STATE) && (handleIndex < graphicsStates_.rhr.size())) {
1734         desc = GetIdDesc(graphicsStates_.nameToIndex, handleIndex);
1735     } else if ((handleType == RenderHandleType::PIPELINE_LAYOUT) && (handleIndex < pl_.rhr.size())) {
1736         desc = GetIdDesc(pl_.nameToIndex, handleIndex);
1737     } else if ((handleType == RenderHandleType::VERTEX_INPUT_DECLARATION) && (handleIndex < shaderVid_.rhr.size())) {
1738         desc = GetIdDesc(shaderVid_.nameToIndex, handleIndex);
1739     }
1740     return desc;
1741 }
1742 
CreateShaderPipelineBinder(const RenderHandleReference & handle) const1743 IShaderPipelineBinder::Ptr ShaderManager::CreateShaderPipelineBinder(const RenderHandleReference& handle) const
1744 {
1745     const RenderHandleType type = handle.GetHandleType();
1746     if (handle &&
1747         ((type == RenderHandleType::SHADER_STATE_OBJECT) || (type == RenderHandleType::COMPUTE_SHADER_STATE_OBJECT))) {
1748         return IShaderPipelineBinder::Ptr { new ShaderPipelineBinder(handle, GetReflectionPipelineLayout(handle)) };
1749     }
1750     return nullptr;
1751 }
1752 
GetCompatibilityFlags(const RenderHandle & lhs,const RenderHandle & rhs) const1753 ShaderManager::CompatibilityFlags ShaderManager::GetCompatibilityFlags(
1754     const RenderHandle& lhs, const RenderHandle& rhs) const
1755 {
1756     const RenderHandleType lType = RenderHandleUtil::GetHandleType(lhs);
1757     const RenderHandleType rType = RenderHandleUtil::GetHandleType(rhs);
1758     CompatibilityFlags flags = 0;
1759     // NOTE: only same types supported at the moment
1760     if (lType == rType) {
1761         if (lType == RenderHandleType::PIPELINE_LAYOUT) {
1762             const PipelineLayout lpl = GetPipelineLayout(lhs);
1763             const PipelineLayout rpl = GetPipelineLayout(rhs);
1764             flags = GetPipelineLayoutCompatibilityFlags(lpl, rpl);
1765         } else if ((lType == RenderHandleType::SHADER_STATE_OBJECT) ||
1766                    (lType == RenderHandleType::COMPUTE_SHADER_STATE_OBJECT)) {
1767             // first check that given pipeline layout is valid to own reflection
1768             const RenderHandle shaderPlHandle = GetPipelineLayoutHandleByShaderHandle(rhs).GetHandle();
1769             if (RenderHandleUtil::IsValid(shaderPlHandle)) {
1770                 const PipelineLayout shaderPl = GetPipelineLayout(shaderPlHandle);
1771                 const PipelineLayout rpl = GetReflectionPipelineLayoutRef(rhs);
1772                 if (rpl.descriptorSetCount > 0) {
1773                     flags = GetPipelineLayoutCompatibilityFlags(rpl, shaderPl);
1774                 }
1775             }
1776             // then, compare to lhs with rhs reflection
1777             if (flags != 0) {
1778                 const RenderHandle lShaderPlHandle = GetPipelineLayoutHandleByShaderHandle(lhs).GetHandle();
1779                 const PipelineLayout lpl = RenderHandleUtil::IsValid(lShaderPlHandle)
1780                                                ? GetPipelineLayout(lShaderPlHandle)
1781                                                : GetReflectionPipelineLayoutRef(lhs);
1782                 flags = GetPipelineLayoutCompatibilityFlags(lpl, GetReflectionPipelineLayoutRef(rhs));
1783             }
1784         }
1785     }
1786     return flags;
1787 }
1788 
GetCompatibilityFlags(const RenderHandleReference & lhs,const RenderHandleReference & rhs) const1789 ShaderManager::CompatibilityFlags ShaderManager::GetCompatibilityFlags(
1790     const RenderHandleReference& lhs, const RenderHandleReference& rhs) const
1791 {
1792     if (lhs && rhs) {
1793         return GetCompatibilityFlags(lhs.GetHandle(), rhs.GetHandle());
1794     } else {
1795         return CompatibilityFlags { 0 };
1796     }
1797 }
1798 
SetFileManager(IFileManager & fileMgr)1799 void ShaderManager::SetFileManager(IFileManager& fileMgr)
1800 {
1801     fileMgr_ = &fileMgr;
1802     shaderLoader_ = make_unique<ShaderLoader>(*fileMgr_, *this, device_.GetBackendType());
1803 }
1804 
// Magic bytes expected at the very start of a shader reflection blob: "rfl" followed by a zero byte.
constexpr uint8_t REFLECTION_TAG[] = { 'r', 'f', 'l', 0 };

// Header at the start of a shader reflection blob. It is read directly from raw bytes, so its layout must
// exactly match the serialized data.
//  - type: shader stage flags of the blob.
//  - offset*: byte offsets of the corresponding sections from the start of the blob. The readers in this
//    file treat 0 as "section absent".
struct ReflectionHeader {
    uint8_t tag[sizeof(REFLECTION_TAG)];
    uint16_t type;
    uint16_t offsetPushConstants;
    uint16_t offsetSpecializationConstants;
    uint16_t offsetDescriptorSets;
    uint16_t offsetInputs;
    uint16_t offsetLocalSize;
};
// Guard the raw-byte decoding: any padding would shift every field relative to the serialized data.
static_assert(sizeof(ReflectionHeader) == sizeof(REFLECTION_TAG) + 6u * sizeof(uint16_t),
    "ReflectionHeader must match the serialized reflection header layout");
1815 
IsValid() const1816 bool ShaderReflectionData::IsValid() const
1817 {
1818     if (reflectionData.size() < sizeof(ReflectionHeader)) {
1819         return false;
1820     }
1821     const ReflectionHeader& header = *reinterpret_cast<const ReflectionHeader*>(reflectionData.data());
1822     return memcmp(header.tag, REFLECTION_TAG, sizeof(REFLECTION_TAG)) == 0;
1823 }
1824 
GetStageFlags() const1825 ShaderStageFlags ShaderReflectionData::GetStageFlags() const
1826 {
1827     ShaderStageFlags flags;
1828     const ReflectionHeader& header = *reinterpret_cast<const ReflectionHeader*>(reflectionData.data());
1829     flags = header.type;
1830     return flags;
1831 }
1832 
GetPipelineLayout() const1833 PipelineLayout ShaderReflectionData::GetPipelineLayout() const
1834 {
1835     PipelineLayout pipelineLayout;
1836     const ReflectionHeader& header = *reinterpret_cast<const ReflectionHeader*>(reflectionData.data());
1837     if (header.offsetPushConstants && header.offsetPushConstants < reflectionData.size()) {
1838         auto ptr = reflectionData.data() + header.offsetPushConstants;
1839         const auto constants = *ptr;
1840         if (constants) {
1841             pipelineLayout.pushConstant.shaderStageFlags = header.type;
1842             pipelineLayout.pushConstant.byteSize = static_cast<uint32_t>(*(ptr + 1) | (*(ptr + 2) << 8));
1843         }
1844     }
1845     if (header.offsetDescriptorSets && header.offsetDescriptorSets < reflectionData.size()) {
1846         auto ptr = reflectionData.data() + header.offsetDescriptorSets;
1847         pipelineLayout.descriptorSetCount = static_cast<uint32_t>(*(ptr) | (*(ptr + 1) << 8));
1848         ptr += 2;
1849         for (auto i = 0u; i < pipelineLayout.descriptorSetCount; ++i) {
1850             // write to correct set location
1851             const uint32_t set = static_cast<uint32_t>(*ptr | (*(ptr + 1) << 8));
1852             PLUGIN_ASSERT(set < PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT);
1853             auto& layout = pipelineLayout.descriptorSetLayouts[set];
1854             layout.set = set;
1855             ptr += 2;
1856             const auto bindings = static_cast<uint32_t>(*ptr | (*(ptr + 1) << 8));
1857             ptr += 2;
1858             for (auto j = 0u; j < bindings; ++j) {
1859                 DescriptorSetLayoutBinding binding;
1860                 binding.binding = static_cast<uint32_t>(*ptr | (*(ptr + 1) << 8));
1861                 ptr += 2;
1862                 binding.descriptorType = static_cast<DescriptorType>(*ptr | (*(ptr + 1) << 8));
1863                 if ((binding.descriptorType > DescriptorType::CORE_DESCRIPTOR_TYPE_INPUT_ATTACHMENT) &&
1864                     (binding.descriptorType ==
1865                         (DescriptorType::CORE_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE & 0xffff))) {
1866                     binding.descriptorType = DescriptorType::CORE_DESCRIPTOR_TYPE_ACCELERATION_STRUCTURE;
1867                 }
1868                 ptr += 2;
1869                 binding.descriptorCount = static_cast<uint32_t>(*ptr | (*(ptr + 1) << 8));
1870                 ptr += 2;
1871                 binding.shaderStageFlags = header.type;
1872                 layout.bindings.push_back(binding);
1873             }
1874         }
1875     }
1876     return pipelineLayout;
1877 }
1878 
GetSpecializationConstants() const1879 vector<ShaderSpecialization::Constant> ShaderReflectionData::GetSpecializationConstants() const
1880 {
1881     vector<ShaderSpecialization::Constant> constants;
1882     const ReflectionHeader& header = *reinterpret_cast<const ReflectionHeader*>(reflectionData.data());
1883     if (header.offsetSpecializationConstants && header.offsetSpecializationConstants < reflectionData.size()) {
1884         auto ptr = reflectionData.data() + header.offsetSpecializationConstants;
1885         const auto size = *ptr | *(ptr + 1) << 8 | *(ptr + 2) << 16 | *(ptr + 3) << 24;
1886         ptr += 4;
1887         for (auto i = 0; i < size; ++i) {
1888             ShaderSpecialization::Constant constant;
1889             constant.shaderStage = header.type;
1890             constant.id = static_cast<uint32_t>(*ptr | *(ptr + 1) << 8 | *(ptr + 2) << 16 | *(ptr + 3) << 24);
1891             ptr += 4;
1892             constant.type = static_cast<ShaderSpecialization::Constant::Type>(
1893                 *ptr | *(ptr + 1) << 8 | *(ptr + 2) << 16 | *(ptr + 3) << 24);
1894             ptr += 4;
1895             constant.offset = 0;
1896             constants.push_back(constant);
1897         }
1898     }
1899     return constants;
1900 }
1901 
GetInputDescriptions() const1902 vector<VertexInputDeclaration::VertexInputAttributeDescription> ShaderReflectionData::GetInputDescriptions() const
1903 {
1904     vector<VertexInputDeclaration::VertexInputAttributeDescription> inputs;
1905     const ReflectionHeader& header = *reinterpret_cast<const ReflectionHeader*>(reflectionData.data());
1906     if (header.offsetInputs && header.offsetInputs < reflectionData.size()) {
1907         auto ptr = reflectionData.data() + header.offsetInputs;
1908         const auto size = *(ptr) | (*(ptr + 1) << 8);
1909         ptr += 2;
1910         for (auto i = 0; i < size; ++i) {
1911             VertexInputDeclaration::VertexInputAttributeDescription desc;
1912             desc.location = static_cast<uint32_t>(*(ptr) | (*(ptr + 1) << 8));
1913             ptr += 2;
1914             desc.binding = desc.location;
1915             desc.format = static_cast<Format>(*(ptr) | (*(ptr + 1) << 8));
1916             ptr += 2;
1917             desc.offset = 0;
1918             inputs.push_back(desc);
1919         }
1920     }
1921     return inputs;
1922 }
1923 
GetLocalSize() const1924 Math::UVec3 ShaderReflectionData::GetLocalSize() const
1925 {
1926     Math::UVec3 sizes;
1927     const ReflectionHeader& header = *reinterpret_cast<const ReflectionHeader*>(reflectionData.data());
1928     if (header.offsetLocalSize && header.offsetLocalSize < reflectionData.size()) {
1929         auto ptr = reflectionData.data() + header.offsetLocalSize;
1930         sizes.x = static_cast<uint32_t>(*ptr | (*(ptr + 1) << 8) | (*(ptr + 2)) << 16 | (*(ptr + 3)) << 24);
1931         ptr += 4;
1932         sizes.y = static_cast<uint32_t>(*ptr | (*(ptr + 1) << 8) | (*(ptr + 2)) << 16 | (*(ptr + 3)) << 24);
1933         ptr += 4;
1934         sizes.z = static_cast<uint32_t>(*ptr | (*(ptr + 1) << 8) | (*(ptr + 2)) << 16 | (*(ptr + 3)) << 24);
1935     }
1936     return sizes;
1937 }
1938 
GetPushConstants() const1939 const uint8_t* ShaderReflectionData::GetPushConstants() const
1940 {
1941     const uint8_t* ptr = nullptr;
1942     const ReflectionHeader& header = *reinterpret_cast<const ReflectionHeader*>(reflectionData.data());
1943     if (header.offsetPushConstants && header.offsetPushConstants < reflectionData.size()) {
1944         const auto constants = *(reflectionData.data() + header.offsetPushConstants);
1945         if (constants) {
1946             // number of constants is uint8 and the size of the constant is uint16
1947             ptr = reflectionData.data() + header.offsetPushConstants + sizeof(uint8_t) + sizeof(uint16_t);
1948         }
1949     }
1950     return ptr;
1951 }
1952 
RenderNodeShaderManager(const ShaderManager & shaderMgr)1953 RenderNodeShaderManager::RenderNodeShaderManager(const ShaderManager& shaderMgr) : shaderMgr_(shaderMgr) {}
1954 
GetShaderHandle(const string_view name) const1955 RenderHandle RenderNodeShaderManager::GetShaderHandle(const string_view name) const
1956 {
1957     return shaderMgr_.GetShaderHandle(name).GetHandle();
1958 }
1959 
GetShaderHandle(const string_view name,const string_view variantName) const1960 RenderHandle RenderNodeShaderManager::GetShaderHandle(const string_view name, const string_view variantName) const
1961 {
1962     return shaderMgr_.GetShaderHandle(name, variantName).GetHandle();
1963 }
1964 
/** Forwards to ShaderManager::GetShaderHandle(handle, renderSlotId) and returns the resulting RenderHandle. */
RenderHandle RenderNodeShaderManager::GetShaderHandle(const RenderHandle& handle, const uint32_t renderSlotId) const
{
    const auto shaderRef = shaderMgr_.GetShaderHandle(handle, renderSlotId);
    return shaderRef.GetHandle();
}
1969 
GetShaders(const uint32_t renderSlotId) const1970 vector<RenderHandle> RenderNodeShaderManager::GetShaders(const uint32_t renderSlotId) const
1971 {
1972     return shaderMgr_.GetShaderRawHandles(renderSlotId);
1973 }
1974 
GetGraphicsStateHandle(const string_view name) const1975 RenderHandle RenderNodeShaderManager::GetGraphicsStateHandle(const string_view name) const
1976 {
1977     return shaderMgr_.GetGraphicsStateHandle(name).GetHandle();
1978 }
1979 
GetGraphicsStateHandle(const string_view name,const string_view variantName) const1980 RenderHandle RenderNodeShaderManager::GetGraphicsStateHandle(
1981     const string_view name, const string_view variantName) const
1982 {
1983     return shaderMgr_.GetGraphicsStateHandle(name, variantName).GetHandle();
1984 }
1985 
/** Forwards to ShaderManager::GetGraphicsStateHandle(handle, renderSlotId) and returns the resulting RenderHandle. */
RenderHandle RenderNodeShaderManager::GetGraphicsStateHandle(const RenderHandle& handle, const uint32_t renderSlotId) const
{
    const auto stateRef = shaderMgr_.GetGraphicsStateHandle(handle, renderSlotId);
    return stateRef.GetHandle();
}
1991 
GetGraphicsStateHandleByHash(const uint64_t hash) const1992 RenderHandle RenderNodeShaderManager::GetGraphicsStateHandleByHash(const uint64_t hash) const
1993 {
1994     return shaderMgr_.GetGraphicsStateHandleByHash(hash).GetHandle();
1995 }
1996 
/** Returns the RenderHandle of the graphics state ShaderManager associates with the given shader handle. */
RenderHandle RenderNodeShaderManager::GetGraphicsStateHandleByShaderHandle(const RenderHandle& handle) const
{
    const auto stateRef = shaderMgr_.GetGraphicsStateHandleByShaderHandle(handle);
    return stateRef.GetHandle();
}
2001 
GetGraphicsState(const RenderHandle & handle) const2002 const GraphicsState& RenderNodeShaderManager::GetGraphicsState(const RenderHandle& handle) const
2003 {
2004     return shaderMgr_.GetGraphicsStateRef(handle);
2005 }
2006 
GetRenderSlotId(const string_view renderSlot) const2007 uint32_t RenderNodeShaderManager::GetRenderSlotId(const string_view renderSlot) const
2008 {
2009     return shaderMgr_.GetRenderSlotId(renderSlot);
2010 }
2011 
GetRenderSlotId(const RenderHandle & handle) const2012 uint32_t RenderNodeShaderManager::GetRenderSlotId(const RenderHandle& handle) const
2013 {
2014     return shaderMgr_.GetRenderSlotId(handle);
2015 }
2016 
GetRenderSlotData(const uint32_t renderSlotId) const2017 IShaderManager::RenderSlotData RenderNodeShaderManager::GetRenderSlotData(const uint32_t renderSlotId) const
2018 {
2019     return shaderMgr_.GetRenderSlotData(renderSlotId);
2020 }
2021 
/** Returns the RenderHandle of the vertex input declaration ShaderManager associates with the shader handle. */
RenderHandle RenderNodeShaderManager::GetVertexInputDeclarationHandleByShaderHandle(const RenderHandle& handle) const
{
    const auto declarationRef = shaderMgr_.GetVertexInputDeclarationHandleByShaderHandle(handle);
    return declarationRef.GetHandle();
}
2026 
GetVertexInputDeclarationHandle(const string_view name) const2027 RenderHandle RenderNodeShaderManager::GetVertexInputDeclarationHandle(const string_view name) const
2028 {
2029     return shaderMgr_.GetVertexInputDeclarationHandle(name).GetHandle();
2030 }
2031 
GetVertexInputDeclarationView(const RenderHandle & handle) const2032 VertexInputDeclarationView RenderNodeShaderManager::GetVertexInputDeclarationView(const RenderHandle& handle) const
2033 {
2034     return shaderMgr_.GetVertexInputDeclarationView(handle);
2035 }
2036 
/** Returns the RenderHandle of the pipeline layout ShaderManager associates with the shader handle. */
RenderHandle RenderNodeShaderManager::GetPipelineLayoutHandleByShaderHandle(const RenderHandle& handle) const
{
    const auto layoutRef = shaderMgr_.GetPipelineLayoutHandleByShaderHandle(handle);
    return layoutRef.GetHandle();
}
2041 
GetPipelineLayout(const RenderHandle & handle) const2042 const PipelineLayout& RenderNodeShaderManager::GetPipelineLayout(const RenderHandle& handle) const
2043 {
2044     return shaderMgr_.GetPipelineLayoutRef(handle);
2045 }
2046 
GetPipelineLayoutHandle(const string_view name) const2047 RenderHandle RenderNodeShaderManager::GetPipelineLayoutHandle(const string_view name) const
2048 {
2049     return shaderMgr_.GetPipelineLayoutHandle(name).GetHandle();
2050 }
2051 
/** Returns the RenderHandle of the reflected pipeline layout ShaderManager reports for the handle. */
RenderHandle RenderNodeShaderManager::GetReflectionPipelineLayoutHandle(const RenderHandle& handle) const
{
    const auto layoutRef = shaderMgr_.GetReflectionPipelineLayoutHandle(handle);
    return layoutRef.GetHandle();
}
2056 
GetReflectionPipelineLayout(const RenderHandle & handle) const2057 const PipelineLayout& RenderNodeShaderManager::GetReflectionPipelineLayout(const RenderHandle& handle) const
2058 {
2059     return shaderMgr_.GetReflectionPipelineLayoutRef(handle);
2060 }
2061 
GetReflectionSpecialization(const RenderHandle & handle) const2062 ShaderSpecilizationConstantView RenderNodeShaderManager::GetReflectionSpecialization(const RenderHandle& handle) const
2063 {
2064     return shaderMgr_.GetReflectionSpecialization(handle);
2065 }
2066 
GetReflectionVertexInputDeclaration(const RenderHandle & handle) const2067 VertexInputDeclarationView RenderNodeShaderManager::GetReflectionVertexInputDeclaration(
2068     const RenderHandle& handle) const
2069 {
2070     return shaderMgr_.GetReflectionVertexInputDeclaration(handle);
2071 }
2072 
GetReflectionThreadGroupSize(const RenderHandle & handle) const2073 ShaderThreadGroup RenderNodeShaderManager::GetReflectionThreadGroupSize(const RenderHandle& handle) const
2074 {
2075     return shaderMgr_.GetReflectionThreadGroupSize(handle);
2076 }
2077 
HashGraphicsState(const GraphicsState & graphicsState) const2078 uint64_t RenderNodeShaderManager::HashGraphicsState(const GraphicsState& graphicsState) const
2079 {
2080     return shaderMgr_.HashGraphicsState(graphicsState);
2081 }
2082 
IsValid(const RenderHandle & handle) const2083 bool RenderNodeShaderManager::IsValid(const RenderHandle& handle) const
2084 {
2085     return RenderHandleUtil::IsValid(handle);
2086 }
2087 
IsComputeShader(const RenderHandle & handle) const2088 bool RenderNodeShaderManager::IsComputeShader(const RenderHandle& handle) const
2089 {
2090     return IsComputeShaderFunc(handle);
2091 }
2092 
IsShader(const RenderHandle & handle) const2093 bool RenderNodeShaderManager::IsShader(const RenderHandle& handle) const
2094 {
2095     return IsShaderFunc(handle);
2096 }
2097 
GetShaders(const RenderHandle & handle,const ShaderStageFlags shaderStageFlags) const2098 vector<RenderHandle> RenderNodeShaderManager::GetShaders(
2099     const RenderHandle& handle, const ShaderStageFlags shaderStageFlags) const
2100 {
2101     return shaderMgr_.GetShaders(handle, shaderStageFlags);
2102 }
2103 
GetCompatibilityFlags(const RenderHandle & lhs,const RenderHandle & rhs) const2104 IShaderManager::CompatibilityFlags RenderNodeShaderManager::GetCompatibilityFlags(
2105     const RenderHandle& lhs, const RenderHandle& rhs) const
2106 {
2107     return shaderMgr_.GetCompatibilityFlags(lhs, rhs);
2108 }
2109 RENDER_END_NAMESPACE()
2110