• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2018 The Dawn Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "dawn_native/ComputePassEncoder.h"
16 
17 #include "dawn_native/BindGroup.h"
18 #include "dawn_native/BindGroupLayout.h"
19 #include "dawn_native/Buffer.h"
20 #include "dawn_native/CommandEncoder.h"
21 #include "dawn_native/CommandValidation.h"
22 #include "dawn_native/Commands.h"
23 #include "dawn_native/ComputePipeline.h"
24 #include "dawn_native/Device.h"
25 #include "dawn_native/InternalPipelineStore.h"
26 #include "dawn_native/ObjectType_autogen.h"
27 #include "dawn_native/PassResourceUsageTracker.h"
28 #include "dawn_native/QuerySet.h"
29 #include "dawn_native/utils/WGPUHelpers.h"
30 
31 namespace dawn_native {
32 
33     namespace {
34 
        // Lazily creates — and caches on the device's InternalPipelineStore — the compute
        // pipeline used to transform indirect dispatch parameters before they are consumed.
        // The shader copies the three workgroup counts from the client buffer into a
        // validated buffer, zeroing any count that exceeds the device limit (when
        // enableValidation is set) and optionally appending a duplicate copy of all three
        // counts (when duplicateNumWorkgroups is set).
        ResultOrError<ComputePipelineBase*> GetOrCreateIndirectDispatchValidationPipeline(
            DeviceBase* device) {
            InternalPipelineStore* store = device->GetInternalPipelineStore();

            // Fast path: the pipeline is created at most once per device.
            if (store->dispatchIndirectValidationPipeline != nullptr) {
                return store->dispatchIndirectValidationPipeline.Get();
            }

            // TODO(https://crbug.com/dawn/1108): Propagate validation feedback from this
            // shader in various failure modes.
            // Type 'bool' cannot be used in storage class 'uniform' as it is non-host-shareable.
            Ref<ShaderModuleBase> shaderModule;
            DAWN_TRY_ASSIGN(shaderModule, utils::CreateShaderModule(device, R"(
                [[block]] struct UniformParams {
                    maxComputeWorkgroupsPerDimension: u32;
                    clientOffsetInU32: u32;
                    enableValidation: u32;
                    duplicateNumWorkgroups: u32;
                };

                [[block]] struct IndirectParams {
                    data: array<u32>;
                };

                [[block]] struct ValidatedParams {
                    data: array<u32>;
                };

                [[group(0), binding(0)]] var<uniform> uniformParams: UniformParams;
                [[group(0), binding(1)]] var<storage, read_write> clientParams: IndirectParams;
                [[group(0), binding(2)]] var<storage, write> validatedParams: ValidatedParams;

                [[stage(compute), workgroup_size(1, 1, 1)]]
                fn main() {
                    for (var i = 0u; i < 3u; i = i + 1u) {
                        var numWorkgroups = clientParams.data[uniformParams.clientOffsetInU32 + i];
                        if (uniformParams.enableValidation > 0u &&
                            numWorkgroups > uniformParams.maxComputeWorkgroupsPerDimension) {
                            numWorkgroups = 0u;
                        }
                        validatedParams.data[i] = numWorkgroups;

                        if (uniformParams.duplicateNumWorkgroups > 0u) {
                             validatedParams.data[i + 3u] = numWorkgroups;
                        }
                    }
                }
            )"));

            // Binding 1 uses kInternalStorageBufferBinding (together with
            // /* allowInternalBinding */ below) so the client's indirect buffer — tracked
            // with the internal storage usage, see the ASSERT in
            // TransformIndirectDispatchBuffer — can be bound as a storage buffer here.
            Ref<BindGroupLayoutBase> bindGroupLayout;
            DAWN_TRY_ASSIGN(
                bindGroupLayout,
                utils::MakeBindGroupLayout(
                    device,
                    {
                        {0, wgpu::ShaderStage::Compute, wgpu::BufferBindingType::Uniform},
                        {1, wgpu::ShaderStage::Compute, kInternalStorageBufferBinding},
                        {2, wgpu::ShaderStage::Compute, wgpu::BufferBindingType::Storage},
                    },
                    /* allowInternalBinding */ true));

            Ref<PipelineLayoutBase> pipelineLayout;
            DAWN_TRY_ASSIGN(pipelineLayout,
                            utils::MakeBasicPipelineLayout(device, bindGroupLayout));

            ComputePipelineDescriptor computePipelineDescriptor = {};
            computePipelineDescriptor.layout = pipelineLayout.Get();
            computePipelineDescriptor.compute.module = shaderModule.Get();
            computePipelineDescriptor.compute.entryPoint = "main";

            // Cache the pipeline on the store so later calls hit the fast path above.
            DAWN_TRY_ASSIGN(store->dispatchIndirectValidationPipeline,
                            device->CreateComputePipeline(&computePipelineDescriptor));

            return store->dispatchIndirectValidationPipeline.Get();
        }
110 
111     }  // namespace
112 
    // Valid-state constructor. Labels the encoder from the descriptor, remembers the
    // parent CommandEncoder (used e.g. for query availability tracking in
    // APIWriteTimestamp), and registers the object with the device via TrackInDevice().
    ComputePassEncoder::ComputePassEncoder(DeviceBase* device,
                                           const ComputePassDescriptor* descriptor,
                                           CommandEncoder* commandEncoder,
                                           EncodingContext* encodingContext)
        : ProgrammableEncoder(device, descriptor->label, encodingContext),
          mCommandEncoder(commandEncoder) {
        TrackInDevice();
    }
121 
    // Error-tagged constructor, used by MakeError(). Unlike the valid-state constructor it
    // does not call TrackInDevice().
    ComputePassEncoder::ComputePassEncoder(DeviceBase* device,
                                           CommandEncoder* commandEncoder,
                                           EncodingContext* encodingContext,
                                           ErrorTag errorTag)
        : ProgrammableEncoder(device, encodingContext, errorTag), mCommandEncoder(commandEncoder) {
    }
128 
MakeError(DeviceBase * device,CommandEncoder * commandEncoder,EncodingContext * encodingContext)129     ComputePassEncoder* ComputePassEncoder::MakeError(DeviceBase* device,
130                                                       CommandEncoder* commandEncoder,
131                                                       EncodingContext* encodingContext) {
132         return new ComputePassEncoder(device, commandEncoder, encodingContext, ObjectBase::kError);
133     }
134 
    // Destruction hook for the encoder object.
    void ComputePassEncoder::DestroyImpl() {
        // Ensure that the pass has exited. This is done for passes only since validation requires
        // they exit before destruction while bundles do not.
        mEncodingContext->EnsurePassExited(this);
    }
140 
    // Returns the runtime object-type tag for this encoder.
    ObjectType ComputePassEncoder::GetType() const {
        return ObjectType::ComputePassEncoder;
    }
144 
    // Ends the compute pass. On successful encoding of the EndComputePass command, the
    // encoding context exits the pass and takes ownership of the resource usage accumulated
    // by mUsageTracker. If TryEncode fails (validation error), the pass is not exited.
    void ComputePassEncoder::APIEndPass() {
        if (mEncodingContext->TryEncode(
                this,
                [&](CommandAllocator* allocator) -> MaybeError {
                    if (IsValidationEnabled()) {
                        // Validates that it is legal to end this programmable encoder now.
                        DAWN_TRY(ValidateProgrammableEncoderEnd());
                    }

                    allocator->Allocate<EndComputePassCmd>(Command::EndComputePass);

                    return {};
                },
                "encoding %s.EndPass().", this)) {
            mEncodingContext->ExitComputePass(this, mUsageTracker.AcquireResourceUsage());
        }
    }
161 
    // Encodes a direct dispatch of (x, y, z) workgroups. With validation enabled, each
    // dimension is checked against the device's maxComputeWorkgroupsPerDimension limit.
    void ComputePassEncoder::APIDispatch(uint32_t x, uint32_t y, uint32_t z) {
        mEncodingContext->TryEncode(
            this,
            [&](CommandAllocator* allocator) -> MaybeError {
                if (IsValidationEnabled()) {
                    // Checks that the current pass state permits a dispatch (rules live in
                    // CommandBufferStateTracker).
                    DAWN_TRY(mCommandBufferState.ValidateCanDispatch());

                    uint32_t workgroupsPerDimension =
                        GetDevice()->GetLimits().v1.maxComputeWorkgroupsPerDimension;

                    DAWN_INVALID_IF(
                        x > workgroupsPerDimension,
                        "Dispatch size X (%u) exceeds max compute workgroups per dimension (%u).",
                        x, workgroupsPerDimension);

                    DAWN_INVALID_IF(
                        y > workgroupsPerDimension,
                        "Dispatch size Y (%u) exceeds max compute workgroups per dimension (%u).",
                        y, workgroupsPerDimension);

                    DAWN_INVALID_IF(
                        z > workgroupsPerDimension,
                        "Dispatch size Z (%u) exceeds max compute workgroups per dimension (%u).",
                        z, workgroupsPerDimension);
                }

                // Record the synchronization scope for Dispatch, which is just the current
                // bindgroups.
                AddDispatchSyncScope();

                DispatchCmd* dispatch = allocator->Allocate<DispatchCmd>(Command::Dispatch);
                dispatch->x = x;
                dispatch->y = y;
                dispatch->z = z;

                return {};
            },
            "encoding %s.Dispatch(%u, %u, %u).", this, x, y, z);
    }
201 
    // Prepares the buffer consumed by a DispatchIndirect command. Returns either the client
    // buffer/offset unchanged (when no transform is needed), or a scratch buffer (with
    // offset 0) filled by an internal dispatch of the pipeline from
    // GetOrCreateIndirectDispatchValidationPipeline: it copies the client's three dispatch
    // parameters, zeroing any that exceed the device limit when validation is enabled, and
    // optionally duplicating all three for backends that request it.
    ResultOrError<std::pair<Ref<BufferBase>, uint64_t>>
    ComputePassEncoder::TransformIndirectDispatchBuffer(Ref<BufferBase> indirectBuffer,
                                                        uint64_t indirectOffset) {
        DeviceBase* device = GetDevice();

        const bool shouldDuplicateNumWorkgroups =
            device->ShouldDuplicateNumWorkgroupsForDispatchIndirect(
                mCommandBufferState.GetComputePipeline());
        if (!IsValidationEnabled() && !shouldDuplicateNumWorkgroups) {
            // Nothing to do: dispatch directly from the client's buffer.
            return std::make_pair(indirectBuffer, indirectOffset);
        }

        // Save the previous command buffer state so it can be restored after the
        // validation inserts additional commands.
        CommandBufferStateTracker previousState = mCommandBufferState;

        auto* const store = device->GetInternalPipelineStore();

        Ref<ComputePipelineBase> validationPipeline;
        DAWN_TRY_ASSIGN(validationPipeline, GetOrCreateIndirectDispatchValidationPipeline(device));

        Ref<BindGroupLayoutBase> layout;
        DAWN_TRY_ASSIGN(layout, validationPipeline->GetBindGroupLayout(0));

        uint32_t storageBufferOffsetAlignment =
            device->GetLimits().v1.minStorageBufferOffsetAlignment;

        // Let the offset be the indirectOffset, aligned down to |storageBufferOffsetAlignment|.
        // The shader then reaches the client data via |clientOffsetInU32| in UniformParams.
        const uint32_t clientOffsetFromAlignedBoundary =
            indirectOffset % storageBufferOffsetAlignment;
        const uint64_t clientOffsetAlignedDown = indirectOffset - clientOffsetFromAlignedBoundary;
        const uint64_t clientIndirectBindingOffset = clientOffsetAlignedDown;

        // Let the size of the binding be the additional offset, plus the size.
        const uint64_t clientIndirectBindingSize =
            kDispatchIndirectSize + clientOffsetFromAlignedBoundary;

        // Neither 'enableValidation' nor 'duplicateNumWorkgroups' can be declared as 'bool' as
        // currently in WGSL type 'bool' cannot be used in storage class 'uniform' as 'it is
        // non-host-shareable'.
        // NOTE: must stay layout-compatible with the UniformParams struct in the WGSL above.
        struct UniformParams {
            uint32_t maxComputeWorkgroupsPerDimension;
            uint32_t clientOffsetInU32;
            uint32_t enableValidation;
            uint32_t duplicateNumWorkgroups;
        };

        // Create a uniform buffer to hold parameters for the shader.
        Ref<BufferBase> uniformBuffer;
        {
            UniformParams params;
            params.maxComputeWorkgroupsPerDimension =
                device->GetLimits().v1.maxComputeWorkgroupsPerDimension;
            params.clientOffsetInU32 = clientOffsetFromAlignedBoundary / sizeof(uint32_t);
            params.enableValidation = static_cast<uint32_t>(IsValidationEnabled());
            params.duplicateNumWorkgroups = static_cast<uint32_t>(shouldDuplicateNumWorkgroups);

            DAWN_TRY_ASSIGN(uniformBuffer, utils::CreateBufferFromData(
                                               device, wgpu::BufferUsage::Uniform, {params}));
        }

        // Reserve space in the scratch buffer to hold the validated indirect params.
        // Twice the size when the three params are duplicated.
        ScratchBuffer& scratchBuffer = store->scratchIndirectStorage;
        const uint64_t scratchBufferSize =
            shouldDuplicateNumWorkgroups ? 2 * kDispatchIndirectSize : kDispatchIndirectSize;
        DAWN_TRY(scratchBuffer.EnsureCapacity(scratchBufferSize));
        Ref<BufferBase> validatedIndirectBuffer = scratchBuffer.GetBuffer();

        Ref<BindGroupBase> validationBindGroup;
        ASSERT(indirectBuffer->GetUsage() & kInternalStorageBuffer);
        DAWN_TRY_ASSIGN(validationBindGroup,
                        utils::MakeBindGroup(device, layout,
                                             {
                                                 {0, uniformBuffer},
                                                 {1, indirectBuffer, clientIndirectBindingOffset,
                                                  clientIndirectBindingSize},
                                                 {2, validatedIndirectBuffer, 0, scratchBufferSize},
                                             }));

        // Issue commands to validate the indirect buffer.
        APISetPipeline(validationPipeline.Get());
        APISetBindGroup(0, validationBindGroup.Get());
        APIDispatch(1);

        // Restore the state.
        RestoreCommandBufferState(std::move(previousState));

        // Return the new indirect buffer and indirect buffer offset.
        return std::make_pair(std::move(validatedIndirectBuffer), uint64_t(0));
    }
292 
    // Encodes an indirect dispatch whose workgroup counts are read from |indirectBuffer| at
    // |indirectOffset|. The buffer/offset may be replaced by TransformIndirectDispatchBuffer
    // (validation and/or num-workgroups duplication); the recorded command then references
    // the scratch buffer instead of the client buffer.
    void ComputePassEncoder::APIDispatchIndirect(BufferBase* indirectBuffer,
                                                 uint64_t indirectOffset) {
        mEncodingContext->TryEncode(
            this,
            [&](CommandAllocator* allocator) -> MaybeError {
                if (IsValidationEnabled()) {
                    DAWN_TRY(GetDevice()->ValidateObject(indirectBuffer));
                    DAWN_TRY(ValidateCanUseAs(indirectBuffer, wgpu::BufferUsage::Indirect));
                    DAWN_TRY(mCommandBufferState.ValidateCanDispatch());

                    DAWN_INVALID_IF(indirectOffset % 4 != 0,
                                    "Indirect offset (%u) is not a multiple of 4.", indirectOffset);

                    // The first comparison also protects the second from uint64 overflow:
                    // any offset large enough to overflow the addition is already >= the
                    // buffer size.
                    DAWN_INVALID_IF(
                        indirectOffset >= indirectBuffer->GetSize() ||
                            indirectOffset + kDispatchIndirectSize > indirectBuffer->GetSize(),
                        "Indirect offset (%u) and dispatch size (%u) exceeds the indirect buffer "
                        "size (%u).",
                        indirectOffset, kDispatchIndirectSize, indirectBuffer->GetSize());
                }

                SyncScopeUsageTracker scope;
                scope.BufferUsedAs(indirectBuffer, wgpu::BufferUsage::Indirect);
                mUsageTracker.AddReferencedBuffer(indirectBuffer);
                // TODO(crbug.com/dawn/1166): If validation is enabled, adding |indirectBuffer|
                // is needed for correct usage validation even though it will only be bound for
                // storage. This will unecessarily transition the |indirectBuffer| in
                // the backend.

                Ref<BufferBase> indirectBufferRef = indirectBuffer;

                // Get applied indirect buffer with necessary changes on the original indirect
                // buffer. For example,
                // - Validate each indirect dispatch with a single dispatch to copy the indirect
                //   buffer params into a scratch buffer if they're valid, and otherwise zero them
                //   out.
                // - Duplicate all the indirect dispatch parameters to support [[num_workgroups]] on
                //   D3D12.
                // - Directly return the original indirect dispatch buffer if we don't need any
                //   transformations on it.
                // We could consider moving the validation earlier in the pass after the last
                // last point the indirect buffer was used with writable usage, as well as batch
                // validation for multiple dispatches into one, but inserting commands at
                // arbitrary points in the past is not possible right now.
                DAWN_TRY_ASSIGN(std::tie(indirectBufferRef, indirectOffset),
                                TransformIndirectDispatchBuffer(indirectBufferRef, indirectOffset));

                // If we have created a new scratch dispatch indirect buffer in
                // TransformIndirectDispatchBuffer(), we need to track it in mUsageTracker.
                if (indirectBufferRef.Get() != indirectBuffer) {
                    // |indirectBufferRef| was replaced with a scratch buffer. Add it to the
                    // synchronization scope.
                    scope.BufferUsedAs(indirectBufferRef.Get(), wgpu::BufferUsage::Indirect);
                    mUsageTracker.AddReferencedBuffer(indirectBufferRef.Get());
                }

                AddDispatchSyncScope(std::move(scope));

                DispatchIndirectCmd* dispatch =
                    allocator->Allocate<DispatchIndirectCmd>(Command::DispatchIndirect);
                dispatch->indirectBuffer = std::move(indirectBufferRef);
                dispatch->indirectOffset = indirectOffset;
                return {};
            },
            "encoding %s.DispatchIndirect(%s, %u).", this, indirectBuffer, indirectOffset);
    }
359 
APISetPipeline(ComputePipelineBase * pipeline)360     void ComputePassEncoder::APISetPipeline(ComputePipelineBase* pipeline) {
361         mEncodingContext->TryEncode(
362             this,
363             [&](CommandAllocator* allocator) -> MaybeError {
364                 if (IsValidationEnabled()) {
365                     DAWN_TRY(GetDevice()->ValidateObject(pipeline));
366                 }
367 
368                 mCommandBufferState.SetComputePipeline(pipeline);
369 
370                 SetComputePipelineCmd* cmd =
371                     allocator->Allocate<SetComputePipelineCmd>(Command::SetComputePipeline);
372                 cmd->pipeline = pipeline;
373 
374                 return {};
375             },
376             "encoding %s.SetPipeline(%s).", this, pipeline);
377     }
378 
    // Binds |group| at |groupIndexIn| with the given dynamic offsets. Adds the group's
    // resources to the pass usage tracker, records the SetBindGroup command, and mirrors
    // the binding into the frontend state tracker.
    void ComputePassEncoder::APISetBindGroup(uint32_t groupIndexIn,
                                             BindGroupBase* group,
                                             uint32_t dynamicOffsetCount,
                                             const uint32_t* dynamicOffsets) {
        mEncodingContext->TryEncode(
            this,
            [&](CommandAllocator* allocator) -> MaybeError {
                BindGroupIndex groupIndex(groupIndexIn);

                if (IsValidationEnabled()) {
                    DAWN_TRY(ValidateSetBindGroup(groupIndex, group, dynamicOffsetCount,
                                                  dynamicOffsets));
                }

                mUsageTracker.AddResourcesReferencedByBindGroup(group);
                RecordSetBindGroup(allocator, groupIndex, group, dynamicOffsetCount,
                                   dynamicOffsets);
                mCommandBufferState.SetBindGroup(groupIndex, group, dynamicOffsetCount,
                                                 dynamicOffsets);

                return {};
            },
            "encoding %s.SetBindGroup(%u, %s, %u, ...).", this, groupIndexIn, group,
            dynamicOffsetCount);
    }
404 
APIWriteTimestamp(QuerySetBase * querySet,uint32_t queryIndex)405     void ComputePassEncoder::APIWriteTimestamp(QuerySetBase* querySet, uint32_t queryIndex) {
406         mEncodingContext->TryEncode(
407             this,
408             [&](CommandAllocator* allocator) -> MaybeError {
409                 if (IsValidationEnabled()) {
410                     DAWN_TRY(GetDevice()->ValidateObject(querySet));
411                     DAWN_TRY(ValidateTimestampQuery(querySet, queryIndex));
412                 }
413 
414                 mCommandEncoder->TrackQueryAvailability(querySet, queryIndex);
415 
416                 WriteTimestampCmd* cmd =
417                     allocator->Allocate<WriteTimestampCmd>(Command::WriteTimestamp);
418                 cmd->querySet = querySet;
419                 cmd->queryIndex = queryIndex;
420 
421                 return {};
422             },
423             "encoding %s.WriteTimestamp(%s, %u).", this, querySet, queryIndex);
424     }
425 
    // Completes the synchronization scope for one dispatch: adds every bind group that is
    // currently bound (per the pipeline layout's bind-group mask) to |scope|, then hands
    // the scope's usage to the pass usage tracker.
    void ComputePassEncoder::AddDispatchSyncScope(SyncScopeUsageTracker scope) {
        PipelineLayoutBase* layout = mCommandBufferState.GetPipelineLayout();
        for (BindGroupIndex i : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
            scope.AddBindGroup(mCommandBufferState.GetBindGroup(i));
        }
        mUsageTracker.AddDispatch(scope.AcquireSyncScopeUsage());
    }
433 
RestoreCommandBufferState(CommandBufferStateTracker state)434     void ComputePassEncoder::RestoreCommandBufferState(CommandBufferStateTracker state) {
435         // Encode commands for the backend to restore the pipeline and bind groups.
436         if (state.HasPipeline()) {
437             APISetPipeline(state.GetComputePipeline());
438         }
439         for (BindGroupIndex i(0); i < kMaxBindGroupsTyped; ++i) {
440             BindGroupBase* bg = state.GetBindGroup(i);
441             if (bg != nullptr) {
442                 const std::vector<uint32_t>& offsets = state.GetDynamicOffsets(i);
443                 if (offsets.empty()) {
444                     APISetBindGroup(static_cast<uint32_t>(i), bg);
445                 } else {
446                     APISetBindGroup(static_cast<uint32_t>(i), bg, offsets.size(), offsets.data());
447                 }
448             }
449         }
450 
451         // Restore the frontend state tracking information.
452         mCommandBufferState = std::move(state);
453     }
454 
    // Test-only accessor exposing the frontend command buffer state tracker.
    CommandBufferStateTracker* ComputePassEncoder::GetCommandBufferStateTrackerForTesting() {
        return &mCommandBufferState;
    }
458 
459 }  // namespace dawn_native
460