1 /*
2 * Copyright (C) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "render_node_compute_generic.h"
17
18 #include <base/math/mathf.h>
19 #include <base/math/vector.h>
20 #include <render/datastore/intf_render_data_store_manager.h>
21 #include <render/datastore/intf_render_data_store_pod.h>
22 #include <render/device/intf_gpu_resource_manager.h>
23 #include <render/device/intf_shader_manager.h>
24 #include <render/namespace.h>
25 #include <render/nodecontext/intf_node_context_descriptor_set_manager.h>
26 #include <render/nodecontext/intf_node_context_pso_manager.h>
27 #include <render/nodecontext/intf_pipeline_descriptor_set_binder.h>
28 #include <render/nodecontext/intf_render_command_list.h>
29 #include <render/nodecontext/intf_render_node_context_manager.h>
30 #include <render/nodecontext/intf_render_node_parser_util.h>
31 #include <render/nodecontext/intf_render_node_util.h>
32
33 #include "util/log.h"
34
35 using namespace BASE_NS;
36
RENDER_BEGIN_NAMESPACE()37 RENDER_BEGIN_NAMESPACE()
38 void RenderNodeComputeGeneric::InitNode(IRenderNodeContextManager& renderNodeContextMgr)
39 {
40 renderNodeContextMgr_ = &renderNodeContextMgr;
41 ParseRenderNodeInputs();
42
43 useDataStoreShaderSpecialization_ = !jsonInputs_.renderDataStoreSpecialization.dataStoreName.empty();
44
45 auto& shaderMgr = renderNodeContextMgr.GetShaderManager();
46 const auto& renderNodeUtil = renderNodeContextMgr.GetRenderNodeUtil();
47 if (!shaderMgr.IsValid(shader_)) {
48 PLUGIN_LOG_E("RenderNodeComputeGeneric needs a valid shader handle");
49 }
50
51 pipelineLayout_ = renderNodeContextMgr.GetRenderNodeUtil().CreatePipelineLayout(shader_);
52 threadGroupSize_ = renderNodeContextMgr.GetShaderManager().GetReflectionThreadGroupSize(shader_);
53
54 const auto& gpuResourceMgr = renderNodeContextMgr.GetGpuResourceManager();
55 RenderHandle targetHandle;
56 for (const auto& imageRef : inputResources_.images) {
57 if (imageRef.set <= pipelineLayout_.descriptorSetCount) {
58 const auto& setRef = pipelineLayout_.descriptorSetLayouts[imageRef.set];
59 if (imageRef.binding < setRef.bindings.size()) {
60 const DescriptorType dt = setRef.bindings[imageRef.binding].descriptorType;
61 if (dt == DescriptorType::CORE_DESCRIPTOR_TYPE_STORAGE_IMAGE) {
62 const GpuImageDesc desc = gpuResourceMgr.GetImageDescriptor(imageRef.handle);
63 targetSize_.x = desc.width;
64 targetSize_.y = desc.height;
65 targetSize_.z = desc.depth;
66 break;
67 }
68 }
69 }
70 }
71 if (!RenderHandleUtil::IsValid(targetHandle)) {
72 PLUGIN_LOG_W("RenderNodeComputeGeneric: cannot automatically determ target size");
73 }
74
75 if (useDataStoreShaderSpecialization_) {
76 const ShaderSpecilizationConstantView sscv =
77 renderNodeContextMgr.GetShaderManager().GetReflectionSpecialization(shader_);
78 shaderSpecializationData_.constants.resize(sscv.constants.size());
79 shaderSpecializationData_.data.resize(sscv.constants.size());
80 for (size_t idx = 0; idx < shaderSpecializationData_.constants.size(); ++idx) {
81 shaderSpecializationData_.constants[idx] = sscv.constants[idx];
82 shaderSpecializationData_.data[idx] = ~0u;
83 }
84 useDataStoreShaderSpecialization_ = !sscv.constants.empty();
85 }
86 psoHandle_ = renderNodeContextMgr.GetPsoManager().GetComputePsoHandle(shader_, pipelineLayout_, {});
87
88 {
89 const DescriptorCounts dc = renderNodeUtil.GetDescriptorCounts(pipelineLayout_);
90 renderNodeContextMgr.GetDescriptorSetManager().ResetAndReserve(dc);
91 }
92
93 pipelineDescriptorSetBinder_ = renderNodeUtil.CreatePipelineDescriptorSetBinder(pipelineLayout_);
94 renderNodeUtil.BindResourcesToBinder(inputResources_, *pipelineDescriptorSetBinder_);
95
96 useDataStorePushConstant_ = (pipelineLayout_.pushConstant.byteSize > 0) &&
97 (!jsonInputs_.renderDataStore.dataStoreName.empty()) &&
98 (!jsonInputs_.renderDataStore.configurationName.empty());
99 }
100
PreExecuteFrame()101 void RenderNodeComputeGeneric::PreExecuteFrame()
102 {
103 // re-create needed gpu resources
104 }
105
ExecuteFrame(IRenderCommandList & cmdList)106 void RenderNodeComputeGeneric::ExecuteFrame(IRenderCommandList& cmdList)
107 {
108 const auto& renderNodeUtil = renderNodeContextMgr_->GetRenderNodeUtil();
109 if (jsonInputs_.hasChangeableResourceHandles) {
110 inputResources_ = renderNodeUtil.CreateInputResources(jsonInputs_.resources);
111 renderNodeUtil.BindResourcesToBinder(inputResources_, *pipelineDescriptorSetBinder_);
112 }
113 {
114 const auto setIndices = pipelineDescriptorSetBinder_->GetSetIndices();
115 for (auto refIndex : setIndices) {
116 const auto descHandle = pipelineDescriptorSetBinder_->GetDescriptorSetHandle(refIndex);
117 const auto bindings = pipelineDescriptorSetBinder_->GetDescriptorSetLayoutBindingResources(refIndex);
118 cmdList.UpdateDescriptorSet(descHandle, bindings);
119 }
120 #if (RENDER_VALIDATION_ENABLED == 1)
121 if (!pipelineDescriptorSetBinder_->GetPipelineDescriptorSetLayoutBindingValidity()) {
122 PLUGIN_LOG_E(
123 "RenderNodeComputeGeneric: bindings missing (RN: %s)", renderNodeContextMgr_->GetName().data());
124 }
125 #endif
126 }
127
128 const RenderHandle psoHandle = GetPsoHandle(*renderNodeContextMgr_);
129 cmdList.BindPipeline(psoHandle);
130
131 // bind all sets
132 {
133 const auto descHandles = pipelineDescriptorSetBinder_->GetDescriptorSetHandles();
134 cmdList.BindDescriptorSets(0, descHandles);
135 }
136
137 // push constants
138 if (useDataStorePushConstant_) {
139 const auto& renderDataStoreMgr = renderNodeContextMgr_->GetRenderDataStoreManager();
140 const auto dataStore = static_cast<IRenderDataStorePod const*>(
141 renderDataStoreMgr.GetRenderDataStore(jsonInputs_.renderDataStore.dataStoreName.c_str()));
142 if (dataStore) {
143 const auto dataView = dataStore->Get(jsonInputs_.renderDataStore.configurationName);
144 if (!dataView.empty()) {
145 cmdList.PushConstant(pipelineLayout_.pushConstant, dataView.data());
146 }
147 }
148 }
149
150 cmdList.Dispatch((targetSize_.x + threadGroupSize_.x - 1u) / threadGroupSize_.x,
151 (targetSize_.y + threadGroupSize_.y - 1u) / threadGroupSize_.y,
152 (targetSize_.z + threadGroupSize_.z - 1u) / threadGroupSize_.z);
153 }
154
GetPsoHandle(IRenderNodeContextManager & renderNodeContextMgr)155 RenderHandle RenderNodeComputeGeneric::GetPsoHandle(IRenderNodeContextManager& renderNodeContextMgr)
156 {
157 if (useDataStoreShaderSpecialization_) {
158 const auto& renderDataStoreMgr = renderNodeContextMgr.GetRenderDataStoreManager();
159 const auto dataStore = static_cast<IRenderDataStorePod const*>(
160 renderDataStoreMgr.GetRenderDataStore(jsonInputs_.renderDataStoreSpecialization.dataStoreName.c_str()));
161 if (dataStore) {
162 const auto dataView = dataStore->Get(jsonInputs_.renderDataStoreSpecialization.configurationName);
163 if (dataView.data() && (dataView.size_bytes() == sizeof(ShaderSpecializationRenderPod))) {
164 const auto* spec = reinterpret_cast<const ShaderSpecializationRenderPod*>(dataView.data());
165 bool valuesChanged = false;
166 const auto specializationCount = Math::min(
167 ShaderSpecializationRenderPod::MAX_SPECIALIZATION_CONSTANT_COUNT,
168 Math::min((uint32_t)shaderSpecializationData_.constants.size(), spec->specializationConstantCount));
169 const auto constantsView = array_view(shaderSpecializationData_.constants.data(), specializationCount);
170 for (const auto& ref : constantsView) {
171 const uint32_t constantId = ref.offset / sizeof(uint32_t);
172 const uint32_t specId = ref.id;
173 if (specId < ShaderSpecializationRenderPod::MAX_SPECIALIZATION_CONSTANT_COUNT) {
174 if (shaderSpecializationData_.data[constantId] != spec->specializationFlags[specId].value) {
175 shaderSpecializationData_.data[constantId] = spec->specializationFlags[specId].value;
176 valuesChanged = true;
177 }
178 }
179 }
180 if (valuesChanged) {
181 const ShaderSpecializationConstantDataView specialization {
182 constantsView,
183 { shaderSpecializationData_.data.data(), specializationCount },
184 };
185 psoHandle_ = renderNodeContextMgr.GetPsoManager().GetComputePsoHandle(
186 shader_, pipelineLayout_, specialization);
187 }
188 } else {
189 const string logName = "RenderNodeComputeGeneric_ShaderSpecialization" +
190 string(jsonInputs_.renderDataStoreSpecialization.configurationName);
191 PLUGIN_LOG_ONCE_E(logName.c_str(),
192 "RenderNodeComputeGeneric shader specilization render data store size mismatch, name: %s, "
193 "size:%u, podsize%u",
194 jsonInputs_.renderDataStoreSpecialization.configurationName.c_str(),
195 static_cast<uint32_t>(sizeof(ShaderSpecializationRenderPod)),
196 static_cast<uint32_t>(dataView.size_bytes()));
197 }
198 }
199 }
200 return psoHandle_;
201 }
202
ParseRenderNodeInputs()203 void RenderNodeComputeGeneric::ParseRenderNodeInputs()
204 {
205 const IRenderNodeParserUtil& parserUtil = renderNodeContextMgr_->GetRenderNodeParserUtil();
206 const auto jsonVal = renderNodeContextMgr_->GetNodeJson();
207 jsonInputs_.resources = parserUtil.GetInputResources(jsonVal, "resources");
208 jsonInputs_.renderDataStore = parserUtil.GetRenderDataStore(jsonVal, "renderDataStore");
209 jsonInputs_.renderDataStoreSpecialization =
210 parserUtil.GetRenderDataStore(jsonVal, "renderDataStoreShaderSpecialization");
211
212 const auto shaderName = parserUtil.GetStringValue(jsonVal, "shader");
213 const IRenderNodeShaderManager& shaderMgr = renderNodeContextMgr_->GetShaderManager();
214 shader_ = shaderMgr.GetShaderHandle(shaderName);
215
216 const auto& renderNodeUtil = renderNodeContextMgr_->GetRenderNodeUtil();
217 inputResources_ = renderNodeUtil.CreateInputResources(jsonInputs_.resources);
218 jsonInputs_.hasChangeableResourceHandles = renderNodeUtil.HasChangeableResources(jsonInputs_.resources);
219 }
220
221 // for plugin / factory interface
Create()222 IRenderNode* RenderNodeComputeGeneric::Create()
223 {
224 return new RenderNodeComputeGeneric();
225 }
226
Destroy(IRenderNode * instance)227 void RenderNodeComputeGeneric::Destroy(IRenderNode* instance)
228 {
229 delete static_cast<RenderNodeComputeGeneric*>(instance);
230 }
231 RENDER_END_NAMESPACE()
232