// Copyright 2021 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "dawn_native/IndirectDrawValidationEncoder.h"

#include "common/Constants.h"
#include "common/Math.h"
#include "dawn_native/BindGroup.h"
#include "dawn_native/BindGroupLayout.h"
#include "dawn_native/CommandEncoder.h"
#include "dawn_native/ComputePassEncoder.h"
#include "dawn_native/ComputePipeline.h"
#include "dawn_native/Device.h"
#include "dawn_native/InternalPipelineStore.h"
#include "dawn_native/Queue.h"
#include "dawn_native/utils/WGPUHelpers.h"

#include <cstdlib>
#include <limits>

namespace dawn_native {

    namespace {
        // NOTE: This must match the workgroup_size attribute on the compute entry point below.
        constexpr uint64_t kWorkgroupSize = 64;

        // Equivalent to the BatchInfo struct defined in the shader below.
        struct BatchInfo {
            uint64_t numIndexBufferElements;
            uint32_t numDraws;
            uint32_t padding;
        };

        // TODO(https://crbug.com/dawn/1108): Propagate validation feedback from this shader in
        // various failure modes.
        static const char sRenderValidationShaderSource[] = R"(
            let kNumIndirectParamsPerDrawCall = 5u;

            let kIndexCountEntry = 0u;
            let kInstanceCountEntry = 1u;
            let kFirstIndexEntry = 2u;
            let kBaseVertexEntry = 3u;
            let kFirstInstanceEntry = 4u;

            [[block]] struct BatchInfo {
                numIndexBufferElementsLow: u32;
                numIndexBufferElementsHigh: u32;
                numDraws: u32;
                padding: u32;
                indirectOffsets: array<u32>;
            };

            [[block]] struct IndirectParams {
                data: array<u32>;
            };

            [[group(0), binding(0)]] var<storage, read> batch: BatchInfo;
            [[group(0), binding(1)]] var<storage, read_write> clientParams: IndirectParams;
            [[group(0), binding(2)]] var<storage, write> validatedParams: IndirectParams;

            fn fail(drawIndex: u32) {
                let index = drawIndex * kNumIndirectParamsPerDrawCall;
                validatedParams.data[index + kIndexCountEntry] = 0u;
                validatedParams.data[index + kInstanceCountEntry] = 0u;
                validatedParams.data[index + kFirstIndexEntry] = 0u;
                validatedParams.data[index + kBaseVertexEntry] = 0u;
                validatedParams.data[index + kFirstInstanceEntry] = 0u;
            }

            fn pass(drawIndex: u32) {
                let vIndex = drawIndex * kNumIndirectParamsPerDrawCall;
                let cIndex = batch.indirectOffsets[drawIndex];
                validatedParams.data[vIndex + kIndexCountEntry] =
                    clientParams.data[cIndex + kIndexCountEntry];
                validatedParams.data[vIndex + kInstanceCountEntry] =
                    clientParams.data[cIndex + kInstanceCountEntry];
                validatedParams.data[vIndex + kFirstIndexEntry] =
                    clientParams.data[cIndex + kFirstIndexEntry];
                validatedParams.data[vIndex + kBaseVertexEntry] =
                    clientParams.data[cIndex + kBaseVertexEntry];
                validatedParams.data[vIndex + kFirstInstanceEntry] =
                    clientParams.data[cIndex + kFirstInstanceEntry];
            }

            [[stage(compute), workgroup_size(64, 1, 1)]]
            fn main([[builtin(global_invocation_id)]] id : vec3<u32>) {
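                // Each invocation validates a single indirect draw. Dispatches are rounded up
                // to whole workgroups of 64, so extra invocations (id.x >= numDraws) just return.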
                if (id.x >= batch.numDraws) {
                    return;
                }

                let clientIndex = batch.indirectOffsets[id.x];
                let firstInstance = clientParams.data[clientIndex + kFirstInstanceEntry];
                if (firstInstance != 0u) {
                    fail(id.x);
                    return;
                }

                if (batch.numIndexBufferElementsHigh >= 2u) {
                    // firstIndex and indexCount are both u32. The maximum possible sum of these
                    // values is 0x1fffffffe, which is less than 0x200000000. Nothing to validate.
                    pass(id.x);
                    return;
                }

                let firstIndex = clientParams.data[clientIndex + kFirstIndexEntry];
                if (batch.numIndexBufferElementsHigh == 0u &&
                    batch.numIndexBufferElementsLow < firstIndex) {
                    fail(id.x);
                    return;
                }

                // Note that this subtraction may underflow, but only when
                // numIndexBufferElementsHigh is 1u. The result is still correct in that case.
                let maxIndexCount = batch.numIndexBufferElementsLow - firstIndex;
                let indexCount = clientParams.data[clientIndex + kIndexCountEntry];
                if (indexCount > maxIndexCount) {
                    fail(id.x);
                    return;
                }
                pass(id.x);
            }
        )";

        ResultOrError<ComputePipelineBase*> GetOrCreateRenderValidationPipeline(
            DeviceBase* device) {
            InternalPipelineStore* store = device->GetInternalPipelineStore();

            if (store->renderValidationPipeline == nullptr) {
                // Create compute shader module if not cached before.
                if (store->renderValidationShader == nullptr) {
                    DAWN_TRY_ASSIGN(
                        store->renderValidationShader,
                        utils::CreateShaderModule(device, sRenderValidationShaderSource));
                }

                Ref<BindGroupLayoutBase> bindGroupLayout;
                DAWN_TRY_ASSIGN(
                    bindGroupLayout,
                    utils::MakeBindGroupLayout(
                        device,
                        {
                            {0, wgpu::ShaderStage::Compute,
                             wgpu::BufferBindingType::ReadOnlyStorage},
                            {1, wgpu::ShaderStage::Compute, kInternalStorageBufferBinding},
                            {2, wgpu::ShaderStage::Compute, wgpu::BufferBindingType::Storage},
                        },
                        /* allowInternalBinding */ true));

                Ref<PipelineLayoutBase> pipelineLayout;
                DAWN_TRY_ASSIGN(pipelineLayout,
                                utils::MakeBasicPipelineLayout(device, bindGroupLayout));

                ComputePipelineDescriptor computePipelineDescriptor = {};
                computePipelineDescriptor.layout = pipelineLayout.Get();
                computePipelineDescriptor.compute.module = store->renderValidationShader.Get();
                computePipelineDescriptor.compute.entryPoint = "main";

                DAWN_TRY_ASSIGN(store->renderValidationPipeline,
                                device->CreateComputePipeline(&computePipelineDescriptor));
            }

            return store->renderValidationPipeline.Get();
        }

        size_t GetBatchDataSize(uint32_t numDraws) {
            return sizeof(BatchInfo) + numDraws * sizeof(uint32_t);
        }

    }  // namespace

    uint32_t ComputeMaxDrawCallsPerIndirectValidationBatch(const CombinedLimits& limits) {
        const uint64_t batchDrawCallLimitByDispatchSize =
            static_cast<uint64_t>(limits.v1.maxComputeWorkgroupsPerDimension) * kWorkgroupSize;
        const uint64_t batchDrawCallLimitByStorageBindingSize =
            (limits.v1.maxStorageBufferBindingSize - sizeof(BatchInfo)) / sizeof(uint32_t);
        return static_cast<uint32_t>(
            std::min({batchDrawCallLimitByDispatchSize, batchDrawCallLimitByStorageBindingSize,
                      uint64_t(std::numeric_limits<uint32_t>::max())}));
    }

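    // Validates every drawIndexedIndirect call recorded in |indirectDrawMetadata| by running the
    // validation shader above. Each draw's parameters are copied into the scratch validated-params
    // buffer, or zeroed out there if validation fails (non-zero firstInstance or an out-of-bounds
    // index range), and the recorded draw commands are repointed to read their indirect data from
    // that scratch buffer instead of the client's buffer.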
    MaybeError EncodeIndirectDrawValidationCommands(DeviceBase* device,
                                                    CommandEncoder* commandEncoder,
                                                    RenderPassResourceUsageTracker* usageTracker,
                                                    IndirectDrawMetadata* indirectDrawMetadata) {
        struct Batch {
            const IndirectDrawMetadata::IndexedIndirectValidationBatch* metadata;
            uint64_t numIndexBufferElements;
            uint64_t dataBufferOffset;
            uint64_t dataSize;
            uint64_t clientIndirectOffset;
            uint64_t clientIndirectSize;
            uint64_t validatedParamsOffset;
            uint64_t validatedParamsSize;
            BatchInfo* batchInfo;
        };

        struct Pass {
            BufferBase* clientIndirectBuffer;
            uint64_t validatedParamsSize = 0;
            uint64_t batchDataSize = 0;
            std::unique_ptr<void, void (*)(void*)> batchData{nullptr, std::free};
            std::vector<Batch> batches;
        };

        // First stage is grouping all batches into passes. We try to pack as many batches into a
        // single pass as possible. Batches can be grouped together as long as they're validating
        // data from the same indirect buffer, but they may still be split into multiple passes if
        // the number of draw calls in a pass would exceed some (very high) upper bound.
        size_t validatedParamsSize = 0;
        std::vector<Pass> passes;
        IndirectDrawMetadata::IndexedIndirectBufferValidationInfoMap& bufferInfoMap =
            *indirectDrawMetadata->GetIndexedIndirectBufferValidationInfo();
        if (bufferInfoMap.empty()) {
            return {};
        }

        const uint32_t maxStorageBufferBindingSize =
            device->GetLimits().v1.maxStorageBufferBindingSize;
        const uint32_t minStorageBufferOffsetAlignment =
            device->GetLimits().v1.minStorageBufferOffsetAlignment;

        for (auto& entry : bufferInfoMap) {
            const IndirectDrawMetadata::IndexedIndirectConfig& config = entry.first;
            BufferBase* clientIndirectBuffer = config.first;
            for (const IndirectDrawMetadata::IndexedIndirectValidationBatch& batch :
                 entry.second.GetBatches()) {
                const uint64_t minOffsetFromAlignedBoundary =
                    batch.minOffset % minStorageBufferOffsetAlignment;
                const uint64_t minOffsetAlignedDown =
                    batch.minOffset - minOffsetFromAlignedBoundary;

                Batch newBatch;
                newBatch.metadata = &batch;
                newBatch.numIndexBufferElements = config.second;
                newBatch.dataSize = GetBatchDataSize(batch.draws.size());
                newBatch.clientIndirectOffset = minOffsetAlignedDown;
                newBatch.clientIndirectSize =
                    batch.maxOffset + kDrawIndexedIndirectSize - minOffsetAlignedDown;

                newBatch.validatedParamsSize = batch.draws.size() * kDrawIndexedIndirectSize;
                newBatch.validatedParamsOffset =
                    Align(validatedParamsSize, minStorageBufferOffsetAlignment);
                validatedParamsSize =
                    newBatch.validatedParamsOffset + newBatch.validatedParamsSize;
                if (validatedParamsSize > maxStorageBufferBindingSize) {
                    return DAWN_INTERNAL_ERROR("Too many drawIndexedIndirect calls to validate");
                }

                Pass* currentPass = passes.empty() ? nullptr : &passes.back();
                if (currentPass && currentPass->clientIndirectBuffer == clientIndirectBuffer) {
                    uint64_t nextBatchDataOffset =
                        Align(currentPass->batchDataSize, minStorageBufferOffsetAlignment);
                    uint64_t newPassBatchDataSize = nextBatchDataOffset + newBatch.dataSize;
                    if (newPassBatchDataSize <= maxStorageBufferBindingSize) {
                        // We can fit this batch in the current pass.
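                        // Its data is appended after what the pass already holds, aligned so the
                        // batch can later be bound at its own offset in the batch data buffer.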
                        newBatch.dataBufferOffset = nextBatchDataOffset;
                        currentPass->batchDataSize = newPassBatchDataSize;
                        currentPass->batches.push_back(newBatch);
                        continue;
                    }
                }

                // We need to start a new pass for this batch.
                newBatch.dataBufferOffset = 0;

                Pass newPass;
                newPass.clientIndirectBuffer = clientIndirectBuffer;
                newPass.batchDataSize = newBatch.dataSize;
                newPass.batches.push_back(newBatch);
                passes.push_back(std::move(newPass));
            }
        }

        auto* const store = device->GetInternalPipelineStore();
        ScratchBuffer& validatedParamsBuffer = store->scratchIndirectStorage;
        ScratchBuffer& batchDataBuffer = store->scratchStorage;

        uint64_t requiredBatchDataBufferSize = 0;
        for (const Pass& pass : passes) {
            requiredBatchDataBufferSize =
                std::max(requiredBatchDataBufferSize, pass.batchDataSize);
        }
        DAWN_TRY(batchDataBuffer.EnsureCapacity(requiredBatchDataBufferSize));
        usageTracker->BufferUsedAs(batchDataBuffer.GetBuffer(), wgpu::BufferUsage::Storage);

        DAWN_TRY(validatedParamsBuffer.EnsureCapacity(validatedParamsSize));
        usageTracker->BufferUsedAs(validatedParamsBuffer.GetBuffer(), wgpu::BufferUsage::Indirect);

        // Now we allocate and populate host-side batch data to be copied to the GPU.
        for (Pass& pass : passes) {
            // We use std::malloc here because it guarantees maximal scalar alignment.
            pass.batchData = {std::malloc(pass.batchDataSize), std::free};
            memset(pass.batchData.get(), 0, pass.batchDataSize);
            uint8_t* batchData = static_cast<uint8_t*>(pass.batchData.get());
            for (Batch& batch : pass.batches) {
                batch.batchInfo = new (&batchData[batch.dataBufferOffset]) BatchInfo();
                batch.batchInfo->numIndexBufferElements = batch.numIndexBufferElements;
                batch.batchInfo->numDraws = static_cast<uint32_t>(batch.metadata->draws.size());

                uint32_t* indirectOffsets = reinterpret_cast<uint32_t*>(batch.batchInfo + 1);
                uint64_t validatedParamsOffset = batch.validatedParamsOffset;
                for (auto& draw : batch.metadata->draws) {
                    // The shader uses this to index an array of u32, hence the division by 4
                    // bytes.
                    *indirectOffsets++ = static_cast<uint32_t>(
                        (draw.clientBufferOffset - batch.clientIndirectOffset) / 4);

                    draw.cmd->indirectBuffer = validatedParamsBuffer.GetBuffer();
                    draw.cmd->indirectOffset = validatedParamsOffset;

                    validatedParamsOffset += kDrawIndexedIndirectSize;
                }
            }
        }

        ComputePipelineBase* pipeline;
        DAWN_TRY_ASSIGN(pipeline, GetOrCreateRenderValidationPipeline(device));

        Ref<BindGroupLayoutBase> layout;
        DAWN_TRY_ASSIGN(layout, pipeline->GetBindGroupLayout(0));

        BindGroupEntry bindings[3];
        BindGroupEntry& bufferDataBinding = bindings[0];
        bufferDataBinding.binding = 0;
        bufferDataBinding.buffer = batchDataBuffer.GetBuffer();

        BindGroupEntry& clientIndirectBinding = bindings[1];
        clientIndirectBinding.binding = 1;

        BindGroupEntry& validatedParamsBinding = bindings[2];
        validatedParamsBinding.binding = 2;
        validatedParamsBinding.buffer = validatedParamsBuffer.GetBuffer();

        BindGroupDescriptor bindGroupDescriptor = {};
        bindGroupDescriptor.layout = layout.Get();
        bindGroupDescriptor.entryCount = 3;
        bindGroupDescriptor.entries = bindings;

        // Finally, we can now encode our validation passes. Each pass first does a single
        // WriteBuffer to get batch data over to the GPU, followed by a single compute pass. The
        // compute pass encodes a separate SetBindGroup and Dispatch command for each batch.
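        // Note that batchDataBuffer is reused by every pass: each pass overwrites it with its
        // own batch data before its dispatches are encoded.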
        for (const Pass& pass : passes) {
            commandEncoder->APIWriteBuffer(batchDataBuffer.GetBuffer(), 0,
                                           static_cast<const uint8_t*>(pass.batchData.get()),
                                           pass.batchDataSize);

            // TODO(dawn:723): change to not use AcquireRef for reentrant object creation.
            ComputePassDescriptor descriptor = {};
            Ref<ComputePassEncoder> passEncoder =
                AcquireRef(commandEncoder->APIBeginComputePass(&descriptor));
            passEncoder->APISetPipeline(pipeline);

            clientIndirectBinding.buffer = pass.clientIndirectBuffer;

            for (const Batch& batch : pass.batches) {
                bufferDataBinding.offset = batch.dataBufferOffset;
                bufferDataBinding.size = batch.dataSize;
                clientIndirectBinding.offset = batch.clientIndirectOffset;
                clientIndirectBinding.size = batch.clientIndirectSize;
                validatedParamsBinding.offset = batch.validatedParamsOffset;
                validatedParamsBinding.size = batch.validatedParamsSize;

                Ref<BindGroupBase> bindGroup;
                DAWN_TRY_ASSIGN(bindGroup, device->CreateBindGroup(&bindGroupDescriptor));

                const uint32_t numDrawsRoundedUp =
                    (batch.batchInfo->numDraws + kWorkgroupSize - 1) / kWorkgroupSize;
                passEncoder->APISetBindGroup(0, bindGroup.Get());
                passEncoder->APIDispatch(numDrawsRoundedUp);
            }

            passEncoder->APIEndPass();
        }

        return {};
    }

}  // namespace dawn_native