1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "render_node_compute_generic.h"
17
18 #include <base/math/mathf.h>
19 #include <base/math/vector.h>
20 #include <render/datastore/intf_render_data_store_manager.h>
21 #include <render/datastore/intf_render_data_store_pod.h>
22 #include <render/device/intf_gpu_resource_manager.h>
23 #include <render/device/intf_shader_manager.h>
24 #include <render/namespace.h>
25 #include <render/nodecontext/intf_node_context_descriptor_set_manager.h>
26 #include <render/nodecontext/intf_node_context_pso_manager.h>
27 #include <render/nodecontext/intf_pipeline_descriptor_set_binder.h>
28 #include <render/nodecontext/intf_render_command_list.h>
29 #include <render/nodecontext/intf_render_node_context_manager.h>
30 #include <render/nodecontext/intf_render_node_parser_util.h>
31 #include <render/nodecontext/intf_render_node_util.h>
32
33 #include "util/log.h"
34
35 using namespace BASE_NS;
36
37 RENDER_BEGIN_NAMESPACE()
38 namespace {
39 struct DispatchResources {
40 RenderHandle buffer {};
41 RenderHandle image {};
42 };
43
GetDispatchResources(const RenderNodeHandles::InputResources & ir)44 DispatchResources GetDispatchResources(const RenderNodeHandles::InputResources& ir)
45 {
46 DispatchResources dr;
47 if (!ir.customInputBuffers.empty()) {
48 dr.buffer = ir.customInputBuffers[0].handle;
49 }
50 if (!ir.customInputImages.empty()) {
51 dr.image = ir.customInputImages[0].handle;
52 }
53 return dr;
54 }
55 } // namespace
56
InitNode(IRenderNodeContextManager & renderNodeContextMgr)57 void RenderNodeComputeGeneric::InitNode(IRenderNodeContextManager& renderNodeContextMgr)
58 {
59 renderNodeContextMgr_ = &renderNodeContextMgr;
60 ParseRenderNodeInputs();
61
62 useDataStoreShaderSpecialization_ = !jsonInputs_.renderDataStoreSpecialization.dataStoreName.empty();
63
64 auto& shaderMgr = renderNodeContextMgr.GetShaderManager();
65 const auto& renderNodeUtil = renderNodeContextMgr.GetRenderNodeUtil();
66 if (RenderHandleUtil::GetHandleType(pipelineData_.sd.shader) != RenderHandleType::COMPUTE_SHADER_STATE_OBJECT) {
67 PLUGIN_LOG_E("RenderNodeComputeGeneric needs a valid compute shader handle");
68 }
69 pipelineData_.sd = shaderMgr.GetShaderDataByShaderHandle(pipelineData_.sd.shader);
70 threadGroupSize_ = shaderMgr.GetReflectionThreadGroupSize(pipelineData_.sd.shader);
71
72 if (dispatchResources_.customInputBuffers.empty() && dispatchResources_.customInputImages.empty()) {
73 PLUGIN_LOG_W("RenderNodeComputeGeneric: dispatchResources (GPU buffer or GPU image) needed");
74 }
75
76 if (useDataStoreShaderSpecialization_) {
77 const ShaderSpecializationConstantView sscv = shaderMgr.GetReflectionSpecialization(pipelineData_.sd.shader);
78 shaderSpecializationData_.constants.resize(sscv.constants.size());
79 shaderSpecializationData_.data.resize(sscv.constants.size());
80 for (size_t idx = 0; idx < shaderSpecializationData_.constants.size(); ++idx) {
81 shaderSpecializationData_.constants[idx] = sscv.constants[idx];
82 shaderSpecializationData_.data[idx] = ~0u;
83 }
84 useDataStoreShaderSpecialization_ = !sscv.constants.empty();
85 }
86 pipelineData_.pso = renderNodeContextMgr.GetPsoManager().GetComputePsoHandle(
87 pipelineData_.sd.shader, pipelineData_.sd.pipelineLayoutData, {});
88
89 {
90 const DescriptorCounts dc = renderNodeUtil.GetDescriptorCounts(pipelineData_.sd.pipelineLayoutData);
91 renderNodeContextMgr.GetDescriptorSetManager().ResetAndReserve(dc);
92 }
93
94 pipelineDescriptorSetBinder_ =
95 renderNodeUtil.CreatePipelineDescriptorSetBinder(pipelineData_.sd.pipelineLayoutData);
96 renderNodeUtil.BindResourcesToBinder(inputResources_, *pipelineDescriptorSetBinder_);
97
98 useDataStorePushConstant_ = (pipelineData_.sd.pipelineLayoutData.pushConstant.byteSize > 0) &&
99 (!jsonInputs_.renderDataStore.dataStoreName.empty()) &&
100 (!jsonInputs_.renderDataStore.configurationName.empty());
101 }
102
PreExecuteFrame()103 void RenderNodeComputeGeneric::PreExecuteFrame()
104 {
105 // re-create needed gpu resources
106 }
107
ExecuteFrame(IRenderCommandList & cmdList)108 void RenderNodeComputeGeneric::ExecuteFrame(IRenderCommandList& cmdList)
109 {
110 if (!RenderHandleUtil::IsValid(pipelineData_.sd.shader)) {
111 return; // invalid shader
112 }
113
114 const auto& renderNodeUtil = renderNodeContextMgr_->GetRenderNodeUtil();
115 if (jsonInputs_.hasChangeableResourceHandles) {
116 inputResources_ = renderNodeUtil.CreateInputResources(jsonInputs_.resources);
117 renderNodeUtil.BindResourcesToBinder(inputResources_, *pipelineDescriptorSetBinder_);
118 }
119 if (jsonInputs_.hasChangeableDispatchHandles) {
120 dispatchResources_ = renderNodeUtil.CreateInputResources(jsonInputs_.dispatchResources);
121 }
122 const DispatchResources dr = GetDispatchResources(dispatchResources_);
123 if ((!RenderHandleUtil::IsValid(dr.buffer)) && (!RenderHandleUtil::IsValid(dr.image))) {
124 #if (RENDER_VALIDATION_ENABLED == 1)
125 PLUGIN_LOG_ONCE_W(renderNodeContextMgr_->GetName() + "_no_dr",
126 "RENDER_VALIDATION: RN: %s, no valid dispatch resource", renderNodeContextMgr_->GetName().data());
127 #endif
128 return; // no way to evaluate dispatch size
129 }
130 const uint32_t firstSetIndex = pipelineDescriptorSetBinder_->GetFirstSet();
131 {
132 const auto setIndices = pipelineDescriptorSetBinder_->GetSetIndices();
133 for (auto refIndex : setIndices) {
134 const auto descHandle = pipelineDescriptorSetBinder_->GetDescriptorSetHandle(refIndex);
135 const auto bindings = pipelineDescriptorSetBinder_->GetDescriptorSetLayoutBindingResources(refIndex);
136 cmdList.UpdateDescriptorSet(descHandle, bindings);
137 }
138 #if (RENDER_VALIDATION_ENABLED == 1)
139 if (!pipelineDescriptorSetBinder_->GetPipelineDescriptorSetLayoutBindingValidity()) {
140 PLUGIN_LOG_ONCE_E(renderNodeContextMgr_->GetName() + "_bindings_missing",
141 "RENDER_VALIDATION: RenderNodeComputeGeneric: bindings missing (RN: %s)",
142 renderNodeContextMgr_->GetName().data());
143 }
144 #endif
145 }
146
147 const RenderHandle psoHandle = GetPsoHandle(*renderNodeContextMgr_);
148 cmdList.BindPipeline(psoHandle);
149
150 // bind all sets
151 {
152 const auto descHandles = pipelineDescriptorSetBinder_->GetDescriptorSetHandles();
153 cmdList.BindDescriptorSets(firstSetIndex, descHandles);
154 }
155
156 // push constants
157 if (useDataStorePushConstant_) {
158 const auto& renderDataStoreMgr = renderNodeContextMgr_->GetRenderDataStoreManager();
159 const auto dataStore = static_cast<const IRenderDataStorePod*>(
160 renderDataStoreMgr.GetRenderDataStore(jsonInputs_.renderDataStore.dataStoreName.c_str()));
161 if (dataStore) {
162 const auto dataView = dataStore->Get(jsonInputs_.renderDataStore.configurationName);
163 if (!dataView.empty()) {
164 cmdList.PushConstant(pipelineData_.sd.pipelineLayoutData.pushConstant, dataView.data());
165 }
166 }
167 }
168
169 if (RenderHandleUtil::IsValid(dr.buffer)) {
170 cmdList.DispatchIndirect(dr.buffer, 0);
171 } else if (RenderHandleUtil::IsValid(dr.image)) {
172 const IRenderNodeGpuResourceManager& gpuResourceMgr = renderNodeContextMgr_->GetGpuResourceManager();
173 const GpuImageDesc desc = gpuResourceMgr.GetImageDescriptor(dr.image);
174 const Math::UVec3 targetSize = { desc.width, desc.height, desc.depth };
175 cmdList.Dispatch((targetSize.x + threadGroupSize_.x - 1u) / threadGroupSize_.x,
176 (targetSize.y + threadGroupSize_.y - 1u) / threadGroupSize_.y,
177 (targetSize.z + threadGroupSize_.z - 1u) / threadGroupSize_.z);
178 }
179 }
180
GetPsoHandle(IRenderNodeContextManager & renderNodeContextMgr)181 RenderHandle RenderNodeComputeGeneric::GetPsoHandle(IRenderNodeContextManager& renderNodeContextMgr)
182 {
183 if (!useDataStoreShaderSpecialization_) {
184 return pipelineData_.pso; // early out
185 }
186 const auto& renderDataStoreMgr = renderNodeContextMgr.GetRenderDataStoreManager();
187 const auto dataStore = static_cast<const IRenderDataStorePod*>(
188 renderDataStoreMgr.GetRenderDataStore(jsonInputs_.renderDataStoreSpecialization.dataStoreName.c_str()));
189 if (!dataStore) {
190 return pipelineData_.pso; // early out
191 }
192 const auto dataView = dataStore->Get(jsonInputs_.renderDataStoreSpecialization.configurationName);
193 if (dataView.data() && (dataView.size_bytes() == sizeof(ShaderSpecializationRenderPod))) {
194 const auto* spec = reinterpret_cast<const ShaderSpecializationRenderPod*>(dataView.data());
195 bool valuesChanged = false;
196 const auto specializationCount = Math::min(ShaderSpecializationRenderPod::MAX_SPECIALIZATION_CONSTANT_COUNT,
197 Math::min((uint32_t)shaderSpecializationData_.constants.size(), spec->specializationConstantCount));
198 const auto constantsView = array_view(shaderSpecializationData_.constants.data(), specializationCount);
199 for (const auto& ref : constantsView) {
200 const uint32_t constantId = ref.offset / sizeof(uint32_t);
201 const uint32_t specId = ref.id;
202 if ((specId < ShaderSpecializationRenderPod::MAX_SPECIALIZATION_CONSTANT_COUNT) &&
203 (shaderSpecializationData_.data[constantId] != spec->specializationFlags[specId].value)) {
204 shaderSpecializationData_.data[constantId] = spec->specializationFlags[specId].value;
205 valuesChanged = true;
206 }
207 }
208 if (valuesChanged) {
209 const ShaderSpecializationConstantDataView specialization {
210 constantsView,
211 { shaderSpecializationData_.data.data(), specializationCount },
212 };
213 pipelineData_.pso = renderNodeContextMgr.GetPsoManager().GetComputePsoHandle(
214 pipelineData_.sd.shader, pipelineData_.sd.pipelineLayout, specialization);
215 }
216 } else {
217 #if (RENDER_VALIDATION_ENABLED == 1)
218 const string logName = "RenderNodeComputeGeneric_ShaderSpecialization" +
219 string(jsonInputs_.renderDataStoreSpecialization.configurationName);
220 PLUGIN_LOG_ONCE_E(logName.c_str(),
221 "RENDER_VALIDATION: RenderNodeComputeGeneric shader specilization render data store size mismatch, "
222 "name: %s, size:%u, podsize%u",
223 jsonInputs_.renderDataStoreSpecialization.configurationName.c_str(),
224 static_cast<uint32_t>(sizeof(ShaderSpecializationRenderPod)), static_cast<uint32_t>(dataView.size_bytes()));
225 #endif
226 }
227 return pipelineData_.pso;
228 }
229
ParseRenderNodeInputs()230 void RenderNodeComputeGeneric::ParseRenderNodeInputs()
231 {
232 const IRenderNodeParserUtil& parserUtil = renderNodeContextMgr_->GetRenderNodeParserUtil();
233 const auto jsonVal = renderNodeContextMgr_->GetNodeJson();
234 jsonInputs_.resources = parserUtil.GetInputResources(jsonVal, "resources");
235 jsonInputs_.dispatchResources = parserUtil.GetInputResources(jsonVal, "dispatchResources");
236 jsonInputs_.renderDataStore = parserUtil.GetRenderDataStore(jsonVal, "renderDataStore");
237 jsonInputs_.renderDataStoreSpecialization =
238 parserUtil.GetRenderDataStore(jsonVal, "renderDataStoreShaderSpecialization");
239
240 const auto shaderName = parserUtil.GetStringValue(jsonVal, "shader");
241 const IRenderNodeShaderManager& shaderMgr = renderNodeContextMgr_->GetShaderManager();
242 pipelineData_.sd.shader = shaderMgr.GetShaderHandle(shaderName);
243
244 const auto& renderNodeUtil = renderNodeContextMgr_->GetRenderNodeUtil();
245 inputResources_ = renderNodeUtil.CreateInputResources(jsonInputs_.resources);
246 dispatchResources_ = renderNodeUtil.CreateInputResources(jsonInputs_.dispatchResources);
247 jsonInputs_.hasChangeableResourceHandles = renderNodeUtil.HasChangeableResources(jsonInputs_.resources);
248 jsonInputs_.hasChangeableDispatchHandles = renderNodeUtil.HasChangeableResources(jsonInputs_.dispatchResources);
249 }
250
251 // for plugin / factory interface
Create()252 IRenderNode* RenderNodeComputeGeneric::Create()
253 {
254 return new RenderNodeComputeGeneric();
255 }
256
Destroy(IRenderNode * instance)257 void RenderNodeComputeGeneric::Destroy(IRenderNode* instance)
258 {
259 delete static_cast<RenderNodeComputeGeneric*>(instance);
260 }
261 RENDER_END_NAMESPACE()
262