// Copyright 2021 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "dawn_native/IndirectDrawValidationEncoder.h"

#include "common/Constants.h"
#include "common/Math.h"
#include "dawn_native/BindGroup.h"
#include "dawn_native/BindGroupLayout.h"
#include "dawn_native/CommandEncoder.h"
#include "dawn_native/ComputePassEncoder.h"
#include "dawn_native/ComputePipeline.h"
#include "dawn_native/Device.h"
#include "dawn_native/InternalPipelineStore.h"
#include "dawn_native/Queue.h"
#include "dawn_native/utils/WGPUHelpers.h"

#include <cstdlib>
#include <cstring>
#include <limits>

namespace dawn_native {

    namespace {
        // NOTE: This must match the workgroup_size attribute on the compute entry point below.
        constexpr uint64_t kWorkgroupSize = 64;

        // Equivalent to the BatchInfo struct defined in the shader below.
        struct BatchInfo {
            uint64_t numIndexBufferElements;
            uint32_t numDraws;
            uint32_t padding;
        };

        // TODO(https://crbug.com/dawn/1108): Propagate validation feedback from this shader in
        // various failure modes.
        static const char sRenderValidationShaderSource[] = R"(
            let kNumIndirectParamsPerDrawCall = 5u;

            let kIndexCountEntry = 0u;
            let kInstanceCountEntry = 1u;
            let kFirstIndexEntry = 2u;
            let kBaseVertexEntry = 3u;
            let kFirstInstanceEntry = 4u;

            [[block]] struct BatchInfo {
                numIndexBufferElementsLow: u32;
                numIndexBufferElementsHigh: u32;
                numDraws: u32;
                padding: u32;
                indirectOffsets: array<u32>;
            };

            [[block]] struct IndirectParams {
                data: array<u32>;
            };

            [[group(0), binding(0)]] var<storage, read> batch: BatchInfo;
            [[group(0), binding(1)]] var<storage, read_write> clientParams: IndirectParams;
            [[group(0), binding(2)]] var<storage, write> validatedParams: IndirectParams;

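            // Zeroes out every indirect parameter for this draw in the validated buffer so
            // that the draw becomes a no-op when the render pass consumes it.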
            fn fail(drawIndex: u32) {
                let index = drawIndex * kNumIndirectParamsPerDrawCall;
                validatedParams.data[index + kIndexCountEntry] = 0u;
                validatedParams.data[index + kInstanceCountEntry] = 0u;
                validatedParams.data[index + kFirstIndexEntry] = 0u;
                validatedParams.data[index + kBaseVertexEntry] = 0u;
                validatedParams.data[index + kFirstInstanceEntry] = 0u;
            }

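            // Copies the client's indirect parameters for this draw through to the validated
            // buffer unmodified.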
            fn pass(drawIndex: u32) {
                let vIndex = drawIndex * kNumIndirectParamsPerDrawCall;
                let cIndex = batch.indirectOffsets[drawIndex];
                validatedParams.data[vIndex + kIndexCountEntry] =
                    clientParams.data[cIndex + kIndexCountEntry];
                validatedParams.data[vIndex + kInstanceCountEntry] =
                    clientParams.data[cIndex + kInstanceCountEntry];
                validatedParams.data[vIndex + kFirstIndexEntry] =
                    clientParams.data[cIndex + kFirstIndexEntry];
                validatedParams.data[vIndex + kBaseVertexEntry] =
                    clientParams.data[cIndex + kBaseVertexEntry];
                validatedParams.data[vIndex + kFirstInstanceEntry] =
                    clientParams.data[cIndex + kFirstInstanceEntry];
            }

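            // Each invocation validates one indirect draw call: draws with a non-zero
            // firstInstance, or whose index range exceeds the bound index buffer, are
            // rewritten as no-ops.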
            [[stage(compute), workgroup_size(64, 1, 1)]]
            fn main([[builtin(global_invocation_id)]] id : vec3<u32>) {
                if (id.x >= batch.numDraws) {
                    return;
                }

                let clientIndex = batch.indirectOffsets[id.x];
                let firstInstance = clientParams.data[clientIndex + kFirstInstanceEntry];
                if (firstInstance != 0u) {
                    fail(id.x);
                    return;
                }

                if (batch.numIndexBufferElementsHigh >= 2u) {
                    // firstIndex and indexCount are both u32. The maximum possible sum of these
                    // values is 0x1fffffffe, which is less than 0x200000000. Nothing to validate.
                    pass(id.x);
                    return;
                }

                let firstIndex = clientParams.data[clientIndex + kFirstIndexEntry];
                if (batch.numIndexBufferElementsHigh == 0u &&
                    batch.numIndexBufferElementsLow < firstIndex) {
                    fail(id.x);
                    return;
                }

                // Note that this subtraction may underflow, but only when
                // numIndexBufferElementsHigh is 1u. The result is still correct in that case.
                let maxIndexCount = batch.numIndexBufferElementsLow - firstIndex;
                let indexCount = clientParams.data[clientIndex + kIndexCountEntry];
                if (indexCount > maxIndexCount) {
                    fail(id.x);
                    return;
                }
                pass(id.x);
            }
        )";

        ResultOrError<ComputePipelineBase*> GetOrCreateRenderValidationPipeline(
            DeviceBase* device) {
            InternalPipelineStore* store = device->GetInternalPipelineStore();

            if (store->renderValidationPipeline == nullptr) {
                // Create compute shader module if not cached before.
                if (store->renderValidationShader == nullptr) {
                    DAWN_TRY_ASSIGN(
                        store->renderValidationShader,
                        utils::CreateShaderModule(device, sRenderValidationShaderSource));
                }

                Ref<BindGroupLayoutBase> bindGroupLayout;
                DAWN_TRY_ASSIGN(
                    bindGroupLayout,
                    utils::MakeBindGroupLayout(
                        device,
                        {
                            {0, wgpu::ShaderStage::Compute,
                             wgpu::BufferBindingType::ReadOnlyStorage},
                            {1, wgpu::ShaderStage::Compute, kInternalStorageBufferBinding},
                            {2, wgpu::ShaderStage::Compute, wgpu::BufferBindingType::Storage},
                        },
                        /* allowInternalBinding */ true));

                Ref<PipelineLayoutBase> pipelineLayout;
                DAWN_TRY_ASSIGN(pipelineLayout,
                                utils::MakeBasicPipelineLayout(device, bindGroupLayout));

                ComputePipelineDescriptor computePipelineDescriptor = {};
                computePipelineDescriptor.layout = pipelineLayout.Get();
                computePipelineDescriptor.compute.module = store->renderValidationShader.Get();
                computePipelineDescriptor.compute.entryPoint = "main";

                DAWN_TRY_ASSIGN(store->renderValidationPipeline,
                                device->CreateComputePipeline(&computePipelineDescriptor));
            }

            return store->renderValidationPipeline.Get();
        }

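        // Size of the data uploaded for a single batch: the fixed-size BatchInfo header
        // followed by one u32 indirect offset per draw call.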
        size_t GetBatchDataSize(uint32_t numDraws) {
            return sizeof(BatchInfo) + numDraws * sizeof(uint32_t);
        }

    }  // namespace

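    // The number of draw calls a single batch can validate is bounded both by the maximum
    // dispatch size (each workgroup validates kWorkgroupSize draws) and by how many u32
    // indirect offsets fit in one storage buffer binding alongside the BatchInfo header.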
    uint32_t ComputeMaxDrawCallsPerIndirectValidationBatch(const CombinedLimits& limits) {
        const uint64_t batchDrawCallLimitByDispatchSize =
            static_cast<uint64_t>(limits.v1.maxComputeWorkgroupsPerDimension) * kWorkgroupSize;
        const uint64_t batchDrawCallLimitByStorageBindingSize =
            (limits.v1.maxStorageBufferBindingSize - sizeof(BatchInfo)) / sizeof(uint32_t);
        return static_cast<uint32_t>(
            std::min({batchDrawCallLimitByDispatchSize, batchDrawCallLimitByStorageBindingSize,
                      uint64_t(std::numeric_limits<uint32_t>::max())}));
    }

    MaybeError EncodeIndirectDrawValidationCommands(DeviceBase* device,
                                                    CommandEncoder* commandEncoder,
                                                    RenderPassResourceUsageTracker* usageTracker,
                                                    IndirectDrawMetadata* indirectDrawMetadata) {
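        // A Batch covers a contiguous range of one client indirect buffer whose draws are
        // validated by a single dispatch. It records where its data lives in the batch-data
        // and validated-params scratch buffers.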
        struct Batch {
            const IndirectDrawMetadata::IndexedIndirectValidationBatch* metadata;
            uint64_t numIndexBufferElements;
            uint64_t dataBufferOffset;
            uint64_t dataSize;
            uint64_t clientIndirectOffset;
            uint64_t clientIndirectSize;
            uint64_t validatedParamsOffset;
            uint64_t validatedParamsSize;
            BatchInfo* batchInfo;
        };

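        // A Pass groups batches that validate the same client indirect buffer into one
        // WriteBuffer plus one compute pass, subject to the storage binding size limit.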
        struct Pass {
            BufferBase* clientIndirectBuffer;
            uint64_t validatedParamsSize = 0;
            uint64_t batchDataSize = 0;
            std::unique_ptr<void, void (*)(void*)> batchData{nullptr, std::free};
            std::vector<Batch> batches;
        };

        // First stage is grouping all batches into passes. We try to pack as many batches into a
        // single pass as possible. Batches can be grouped together as long as they're validating
        // data from the same indirect buffer, but they may still be split into multiple passes if
        // the number of draw calls in a pass would exceed some (very high) upper bound.
        size_t validatedParamsSize = 0;
        std::vector<Pass> passes;
        IndirectDrawMetadata::IndexedIndirectBufferValidationInfoMap& bufferInfoMap =
            *indirectDrawMetadata->GetIndexedIndirectBufferValidationInfo();
        if (bufferInfoMap.empty()) {
            return {};
        }

        const uint32_t maxStorageBufferBindingSize =
            device->GetLimits().v1.maxStorageBufferBindingSize;
        const uint32_t minStorageBufferOffsetAlignment =
            device->GetLimits().v1.minStorageBufferOffsetAlignment;

        for (auto& entry : bufferInfoMap) {
            const IndirectDrawMetadata::IndexedIndirectConfig& config = entry.first;
            BufferBase* clientIndirectBuffer = config.first;
            for (const IndirectDrawMetadata::IndexedIndirectValidationBatch& batch :
                 entry.second.GetBatches()) {
                const uint64_t minOffsetFromAlignedBoundary =
                    batch.minOffset % minStorageBufferOffsetAlignment;
                const uint64_t minOffsetAlignedDown =
                    batch.minOffset - minOffsetFromAlignedBoundary;

                Batch newBatch;
                newBatch.metadata = &batch;
                newBatch.numIndexBufferElements = config.second;
                newBatch.dataSize = GetBatchDataSize(batch.draws.size());
                newBatch.clientIndirectOffset = minOffsetAlignedDown;
                newBatch.clientIndirectSize =
                    batch.maxOffset + kDrawIndexedIndirectSize - minOffsetAlignedDown;

                newBatch.validatedParamsSize = batch.draws.size() * kDrawIndexedIndirectSize;
                newBatch.validatedParamsOffset =
                    Align(validatedParamsSize, minStorageBufferOffsetAlignment);
                validatedParamsSize = newBatch.validatedParamsOffset + newBatch.validatedParamsSize;
                if (validatedParamsSize > maxStorageBufferBindingSize) {
                    return DAWN_INTERNAL_ERROR("Too many drawIndexedIndirect calls to validate");
                }

                Pass* currentPass = passes.empty() ? nullptr : &passes.back();
                if (currentPass && currentPass->clientIndirectBuffer == clientIndirectBuffer) {
                    uint64_t nextBatchDataOffset =
                        Align(currentPass->batchDataSize, minStorageBufferOffsetAlignment);
                    uint64_t newPassBatchDataSize = nextBatchDataOffset + newBatch.dataSize;
                    if (newPassBatchDataSize <= maxStorageBufferBindingSize) {
                        // We can fit this batch in the current pass.
                        newBatch.dataBufferOffset = nextBatchDataOffset;
                        currentPass->batchDataSize = newPassBatchDataSize;
                        currentPass->batches.push_back(newBatch);
                        continue;
                    }
                }

                // We need to start a new pass for this batch.
                newBatch.dataBufferOffset = 0;

                Pass newPass;
                newPass.clientIndirectBuffer = clientIndirectBuffer;
                newPass.batchDataSize = newBatch.dataSize;
                newPass.batches.push_back(newBatch);
                passes.push_back(std::move(newPass));
            }
        }

        auto* const store = device->GetInternalPipelineStore();
        ScratchBuffer& validatedParamsBuffer = store->scratchIndirectStorage;
        ScratchBuffer& batchDataBuffer = store->scratchStorage;

        uint64_t requiredBatchDataBufferSize = 0;
        for (const Pass& pass : passes) {
            requiredBatchDataBufferSize = std::max(requiredBatchDataBufferSize, pass.batchDataSize);
        }
        DAWN_TRY(batchDataBuffer.EnsureCapacity(requiredBatchDataBufferSize));
        usageTracker->BufferUsedAs(batchDataBuffer.GetBuffer(), wgpu::BufferUsage::Storage);

        DAWN_TRY(validatedParamsBuffer.EnsureCapacity(validatedParamsSize));
        usageTracker->BufferUsedAs(validatedParamsBuffer.GetBuffer(), wgpu::BufferUsage::Indirect);

        // Now we allocate and populate host-side batch data to be copied to the GPU.
        for (Pass& pass : passes) {
            // We use std::malloc here because it guarantees maximal scalar alignment.
            pass.batchData = {std::malloc(pass.batchDataSize), std::free};
            memset(pass.batchData.get(), 0, pass.batchDataSize);
            uint8_t* batchData = static_cast<uint8_t*>(pass.batchData.get());
            for (Batch& batch : pass.batches) {
                batch.batchInfo = new (&batchData[batch.dataBufferOffset]) BatchInfo();
                batch.batchInfo->numIndexBufferElements = batch.numIndexBufferElements;
                batch.batchInfo->numDraws = static_cast<uint32_t>(batch.metadata->draws.size());

                uint32_t* indirectOffsets = reinterpret_cast<uint32_t*>(batch.batchInfo + 1);
                uint64_t validatedParamsOffset = batch.validatedParamsOffset;
                for (auto& draw : batch.metadata->draws) {
                    // The shader uses this to index an array of u32, hence the division by 4 bytes.
                    *indirectOffsets++ = static_cast<uint32_t>(
                        (draw.clientBufferOffset - batch.clientIndirectOffset) / 4);

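                    // Redirect the recorded draw command to the validated scratch buffer so the
                    // render pass reads the (possibly zeroed) parameters written by the
                    // validation shader instead of the client's buffer.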
                    draw.cmd->indirectBuffer = validatedParamsBuffer.GetBuffer();
                    draw.cmd->indirectOffset = validatedParamsOffset;

                    validatedParamsOffset += kDrawIndexedIndirectSize;
                }
            }
        }

        ComputePipelineBase* pipeline;
        DAWN_TRY_ASSIGN(pipeline, GetOrCreateRenderValidationPipeline(device));

        Ref<BindGroupLayoutBase> layout;
        DAWN_TRY_ASSIGN(layout, pipeline->GetBindGroupLayout(0));

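        // Bindings 0, 1, and 2 correspond to the batch, clientParams, and validatedParams
        // bindings in the validation shader. Offsets and sizes are filled in per batch below.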
        BindGroupEntry bindings[3];
        BindGroupEntry& bufferDataBinding = bindings[0];
        bufferDataBinding.binding = 0;
        bufferDataBinding.buffer = batchDataBuffer.GetBuffer();

        BindGroupEntry& clientIndirectBinding = bindings[1];
        clientIndirectBinding.binding = 1;

        BindGroupEntry& validatedParamsBinding = bindings[2];
        validatedParamsBinding.binding = 2;
        validatedParamsBinding.buffer = validatedParamsBuffer.GetBuffer();

        BindGroupDescriptor bindGroupDescriptor = {};
        bindGroupDescriptor.layout = layout.Get();
        bindGroupDescriptor.entryCount = 3;
        bindGroupDescriptor.entries = bindings;

        // Finally, we can now encode our validation passes. Each pass first does a single
        // WriteBuffer to get batch data over to the GPU, followed by a single compute pass. The
        // compute pass encodes a separate SetBindGroup and Dispatch command for each batch.
        for (const Pass& pass : passes) {
            commandEncoder->APIWriteBuffer(batchDataBuffer.GetBuffer(), 0,
                                           static_cast<const uint8_t*>(pass.batchData.get()),
                                           pass.batchDataSize);

            // TODO(dawn:723): change to not use AcquireRef for reentrant object creation.
            ComputePassDescriptor descriptor = {};
            Ref<ComputePassEncoder> passEncoder =
                AcquireRef(commandEncoder->APIBeginComputePass(&descriptor));
            passEncoder->APISetPipeline(pipeline);

            clientIndirectBinding.buffer = pass.clientIndirectBuffer;

            for (const Batch& batch : pass.batches) {
                bufferDataBinding.offset = batch.dataBufferOffset;
                bufferDataBinding.size = batch.dataSize;
                clientIndirectBinding.offset = batch.clientIndirectOffset;
                clientIndirectBinding.size = batch.clientIndirectSize;
                validatedParamsBinding.offset = batch.validatedParamsOffset;
                validatedParamsBinding.size = batch.validatedParamsSize;

                Ref<BindGroupBase> bindGroup;
                DAWN_TRY_ASSIGN(bindGroup, device->CreateBindGroup(&bindGroupDescriptor));

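                // Each workgroup validates kWorkgroupSize draws, so dispatch enough
                // workgroups to cover every draw in the batch.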
                const uint32_t numDrawsRoundedUp =
                    (batch.batchInfo->numDraws + kWorkgroupSize - 1) / kWorkgroupSize;
                passEncoder->APISetBindGroup(0, bindGroup.Get());
                passEncoder->APIDispatch(numDrawsRoundedUp);
            }

            passEncoder->APIEndPass();
        }

        return {};
    }

}  // namespace dawn_native