// Copyright 2018 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "dawn_native/ComputePassEncoder.h"

#include "dawn_native/BindGroup.h"
#include "dawn_native/BindGroupLayout.h"
#include "dawn_native/Buffer.h"
#include "dawn_native/CommandEncoder.h"
#include "dawn_native/CommandValidation.h"
#include "dawn_native/Commands.h"
#include "dawn_native/ComputePipeline.h"
#include "dawn_native/Device.h"
#include "dawn_native/InternalPipelineStore.h"
#include "dawn_native/ObjectType_autogen.h"
#include "dawn_native/PassResourceUsageTracker.h"
#include "dawn_native/QuerySet.h"
#include "dawn_native/utils/WGPUHelpers.h"

namespace dawn_native {

    namespace {

        ResultOrError<ComputePipelineBase*> GetOrCreateIndirectDispatchValidationPipeline(
            DeviceBase* device) {
            InternalPipelineStore* store = device->GetInternalPipelineStore();

            if (store->dispatchIndirectValidationPipeline != nullptr) {
                return store->dispatchIndirectValidationPipeline.Get();
            }

            // TODO(https://crbug.com/dawn/1108): Propagate validation feedback from this
            // shader in various failure modes.
            // Note: the flag fields below are declared as u32 because in WGSL the type
            // 'bool' cannot be used in storage class 'uniform', as it is non-host-shareable.
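            // The validation shader reads the three dispatch dimensions from the client's
            // indirect buffer. When validation is enabled, it zeroes any dimension that
            // exceeds maxComputeWorkgroupsPerDimension, turning an invalid dispatch into a
            // no-op. When requested, it also writes a second copy of the three values so
            // that the backend can bind them to emulate [[num_workgroups]] (e.g. on D3D12).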
            Ref<ShaderModuleBase> shaderModule;
            DAWN_TRY_ASSIGN(shaderModule, utils::CreateShaderModule(device, R"(
                [[block]] struct UniformParams {
                    maxComputeWorkgroupsPerDimension: u32;
                    clientOffsetInU32: u32;
                    enableValidation: u32;
                    duplicateNumWorkgroups: u32;
                };

                [[block]] struct IndirectParams {
                    data: array<u32>;
                };

                [[block]] struct ValidatedParams {
                    data: array<u32>;
                };

                [[group(0), binding(0)]] var<uniform> uniformParams: UniformParams;
                [[group(0), binding(1)]] var<storage, read_write> clientParams: IndirectParams;
                [[group(0), binding(2)]] var<storage, write> validatedParams: ValidatedParams;

                [[stage(compute), workgroup_size(1, 1, 1)]]
                fn main() {
                    for (var i = 0u; i < 3u; i = i + 1u) {
                        var numWorkgroups = clientParams.data[uniformParams.clientOffsetInU32 + i];
                        if (uniformParams.enableValidation > 0u &&
                            numWorkgroups > uniformParams.maxComputeWorkgroupsPerDimension) {
                            numWorkgroups = 0u;
                        }
                        validatedParams.data[i] = numWorkgroups;

                        if (uniformParams.duplicateNumWorkgroups > 0u) {
                            validatedParams.data[i + 3u] = numWorkgroups;
                        }
                    }
                }
            )"));

            Ref<BindGroupLayoutBase> bindGroupLayout;
            DAWN_TRY_ASSIGN(
                bindGroupLayout,
                utils::MakeBindGroupLayout(
                    device,
                    {
                        {0, wgpu::ShaderStage::Compute, wgpu::BufferBindingType::Uniform},
                        {1, wgpu::ShaderStage::Compute, kInternalStorageBufferBinding},
                        {2, wgpu::ShaderStage::Compute, wgpu::BufferBindingType::Storage},
                    },
                    /* allowInternalBinding */ true));

            Ref<PipelineLayoutBase> pipelineLayout;
            DAWN_TRY_ASSIGN(pipelineLayout,
                            utils::MakeBasicPipelineLayout(device, bindGroupLayout));

            ComputePipelineDescriptor computePipelineDescriptor = {};
            computePipelineDescriptor.layout = pipelineLayout.Get();
            computePipelineDescriptor.compute.module = shaderModule.Get();
            computePipelineDescriptor.compute.entryPoint = "main";

            DAWN_TRY_ASSIGN(store->dispatchIndirectValidationPipeline,
                            device->CreateComputePipeline(&computePipelineDescriptor));

            return store->dispatchIndirectValidationPipeline.Get();
        }

    }  // namespace

    ComputePassEncoder::ComputePassEncoder(DeviceBase* device,
                                           const ComputePassDescriptor* descriptor,
                                           CommandEncoder* commandEncoder,
                                           EncodingContext* encodingContext)
        : ProgrammableEncoder(device, descriptor->label, encodingContext),
          mCommandEncoder(commandEncoder) {
        TrackInDevice();
    }

    ComputePassEncoder::ComputePassEncoder(DeviceBase* device,
                                           CommandEncoder* commandEncoder,
                                           EncodingContext* encodingContext,
                                           ErrorTag errorTag)
        : ProgrammableEncoder(device, encodingContext, errorTag), mCommandEncoder(commandEncoder) {
    }

    ComputePassEncoder* ComputePassEncoder::MakeError(DeviceBase* device,
                                                      CommandEncoder* commandEncoder,
                                                      EncodingContext* encodingContext) {
        return new ComputePassEncoder(device, commandEncoder, encodingContext, ObjectBase::kError);
    }

    void ComputePassEncoder::DestroyImpl() {
        // Ensure that the pass has exited. This is done for passes only, since validation
        // requires they exit before destruction while bundles do not.
        mEncodingContext->EnsurePassExited(this);
    }

    ObjectType ComputePassEncoder::GetType() const {
        return ObjectType::ComputePassEncoder;
    }

    void ComputePassEncoder::APIEndPass() {
        if (mEncodingContext->TryEncode(
                this,
                [&](CommandAllocator* allocator) -> MaybeError {
                    if (IsValidationEnabled()) {
                        DAWN_TRY(ValidateProgrammableEncoderEnd());
                    }

                    allocator->Allocate<EndComputePassCmd>(Command::EndComputePass);

                    return {};
                },
                "encoding %s.EndPass().", this)) {
            mEncodingContext->ExitComputePass(this, mUsageTracker.AcquireResourceUsage());
        }
    }

    void ComputePassEncoder::APIDispatch(uint32_t x, uint32_t y, uint32_t z) {
        mEncodingContext->TryEncode(
            this,
            [&](CommandAllocator* allocator) -> MaybeError {
                if (IsValidationEnabled()) {
                    DAWN_TRY(mCommandBufferState.ValidateCanDispatch());

                    uint32_t workgroupsPerDimension =
                        GetDevice()->GetLimits().v1.maxComputeWorkgroupsPerDimension;

                    DAWN_INVALID_IF(
                        x > workgroupsPerDimension,
                        "Dispatch size X (%u) exceeds max compute workgroups per dimension (%u).",
                        x, workgroupsPerDimension);

                    DAWN_INVALID_IF(
                        y > workgroupsPerDimension,
                        "Dispatch size Y (%u) exceeds max compute workgroups per dimension (%u).",
                        y, workgroupsPerDimension);

                    DAWN_INVALID_IF(
                        z > workgroupsPerDimension,
                        "Dispatch size Z (%u) exceeds max compute workgroups per dimension (%u).",
                        z, workgroupsPerDimension);
                }

                // Record the synchronization scope for Dispatch, which is just the current
                // bindgroups.
                AddDispatchSyncScope();

                DispatchCmd* dispatch = allocator->Allocate<DispatchCmd>(Command::Dispatch);
                dispatch->x = x;
                dispatch->y = y;
                dispatch->z = z;

                return {};
            },
            "encoding %s.Dispatch(%u, %u, %u).", this, x, y, z);
    }

    ResultOrError<std::pair<Ref<BufferBase>, uint64_t>>
    ComputePassEncoder::TransformIndirectDispatchBuffer(Ref<BufferBase> indirectBuffer,
                                                        uint64_t indirectOffset) {
        DeviceBase* device = GetDevice();

        const bool shouldDuplicateNumWorkgroups =
            device->ShouldDuplicateNumWorkgroupsForDispatchIndirect(
                mCommandBufferState.GetComputePipeline());
        if (!IsValidationEnabled() && !shouldDuplicateNumWorkgroups) {
            return std::make_pair(indirectBuffer, indirectOffset);
        }

        // Save the previous command buffer state so it can be restored after the
        // validation inserts additional commands.
        CommandBufferStateTracker previousState = mCommandBufferState;

        auto* const store = device->GetInternalPipelineStore();

        Ref<ComputePipelineBase> validationPipeline;
        DAWN_TRY_ASSIGN(validationPipeline, GetOrCreateIndirectDispatchValidationPipeline(device));

        Ref<BindGroupLayoutBase> layout;
        DAWN_TRY_ASSIGN(layout, validationPipeline->GetBindGroupLayout(0));

        uint32_t storageBufferOffsetAlignment =
            device->GetLimits().v1.minStorageBufferOffsetAlignment;

        // Let the offset be the indirectOffset, aligned down to |storageBufferOffsetAlignment|.
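        // For example (illustrative values): with a 256-byte minStorageBufferOffsetAlignment
        // and an indirectOffset of 264, the binding below starts at offset 256 and the
        // dispatch params sit 8 bytes (two u32s) into the binding, so the binding must be
        // kDispatchIndirectSize + 8 bytes long.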
        const uint32_t clientOffsetFromAlignedBoundary =
            indirectOffset % storageBufferOffsetAlignment;
        const uint64_t clientOffsetAlignedDown = indirectOffset - clientOffsetFromAlignedBoundary;
        const uint64_t clientIndirectBindingOffset = clientOffsetAlignedDown;

        // Let the size of the binding be the additional offset, plus the size.
        const uint64_t clientIndirectBindingSize =
            kDispatchIndirectSize + clientOffsetFromAlignedBoundary;

        // Neither 'enableValidation' nor 'duplicateNumWorkgroups' can be declared as 'bool',
        // because in WGSL the type 'bool' currently cannot be used in storage class 'uniform'
        // (it is non-host-shareable).
        struct UniformParams {
            uint32_t maxComputeWorkgroupsPerDimension;
            uint32_t clientOffsetInU32;
            uint32_t enableValidation;
            uint32_t duplicateNumWorkgroups;
        };

        // Create a uniform buffer to hold parameters for the shader.
        Ref<BufferBase> uniformBuffer;
        {
            UniformParams params;
            params.maxComputeWorkgroupsPerDimension =
                device->GetLimits().v1.maxComputeWorkgroupsPerDimension;
            params.clientOffsetInU32 = clientOffsetFromAlignedBoundary / sizeof(uint32_t);
            params.enableValidation = static_cast<uint32_t>(IsValidationEnabled());
            params.duplicateNumWorkgroups = static_cast<uint32_t>(shouldDuplicateNumWorkgroups);

            DAWN_TRY_ASSIGN(uniformBuffer, utils::CreateBufferFromData(
                                               device, wgpu::BufferUsage::Uniform, {params}));
        }

        // Reserve space in the scratch buffer to hold the validated indirect params.
        ScratchBuffer& scratchBuffer = store->scratchIndirectStorage;
        const uint64_t scratchBufferSize =
            shouldDuplicateNumWorkgroups ? 2 * kDispatchIndirectSize : kDispatchIndirectSize;
        DAWN_TRY(scratchBuffer.EnsureCapacity(scratchBufferSize));
        Ref<BufferBase> validatedIndirectBuffer = scratchBuffer.GetBuffer();

        Ref<BindGroupBase> validationBindGroup;
        ASSERT(indirectBuffer->GetUsage() & kInternalStorageBuffer);
        DAWN_TRY_ASSIGN(validationBindGroup,
                        utils::MakeBindGroup(device, layout,
                                             {
                                                 {0, uniformBuffer},
                                                 {1, indirectBuffer, clientIndirectBindingOffset,
                                                  clientIndirectBindingSize},
                                                 {2, validatedIndirectBuffer, 0, scratchBufferSize},
                                             }));

        // Issue commands to validate the indirect buffer.
        APISetPipeline(validationPipeline.Get());
        APISetBindGroup(0, validationBindGroup.Get());
        APIDispatch(1);

        // Restore the state.
        RestoreCommandBufferState(std::move(previousState));

        // Return the new indirect buffer and indirect buffer offset.
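        // The validation shader wrote the validated params at the start of the scratch
        // buffer, so the dispatch that consumes them always uses offset 0, regardless of
        // the original indirectOffset.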
        return std::make_pair(std::move(validatedIndirectBuffer), uint64_t(0));
    }

    void ComputePassEncoder::APIDispatchIndirect(BufferBase* indirectBuffer,
                                                 uint64_t indirectOffset) {
        mEncodingContext->TryEncode(
            this,
            [&](CommandAllocator* allocator) -> MaybeError {
                if (IsValidationEnabled()) {
                    DAWN_TRY(GetDevice()->ValidateObject(indirectBuffer));
                    DAWN_TRY(ValidateCanUseAs(indirectBuffer, wgpu::BufferUsage::Indirect));
                    DAWN_TRY(mCommandBufferState.ValidateCanDispatch());

                    DAWN_INVALID_IF(indirectOffset % 4 != 0,
                                    "Indirect offset (%u) is not a multiple of 4.", indirectOffset);

                    DAWN_INVALID_IF(
                        indirectOffset >= indirectBuffer->GetSize() ||
                            indirectOffset + kDispatchIndirectSize > indirectBuffer->GetSize(),
                        "Indirect offset (%u) and dispatch size (%u) exceed the indirect buffer "
                        "size (%u).",
                        indirectOffset, kDispatchIndirectSize, indirectBuffer->GetSize());
                }

                SyncScopeUsageTracker scope;
                scope.BufferUsedAs(indirectBuffer, wgpu::BufferUsage::Indirect);
                mUsageTracker.AddReferencedBuffer(indirectBuffer);
                // TODO(crbug.com/dawn/1166): If validation is enabled, adding |indirectBuffer|
                // is needed for correct usage validation even though it will only be bound for
                // storage. This will unnecessarily transition the |indirectBuffer| in
                // the backend.

                Ref<BufferBase> indirectBufferRef = indirectBuffer;

                // Get the indirect buffer to use, applying any necessary changes to the
                // original indirect buffer. For example:
                // - Validate each indirect dispatch with a single dispatch to copy the indirect
                //   buffer params into a scratch buffer if they're valid, and otherwise zero them
                //   out.
                // - Duplicate all the indirect dispatch parameters to support [[num_workgroups]]
                //   on D3D12.
                // - Directly return the original indirect dispatch buffer if we don't need any
                //   transformations on it.
                // We could consider moving the validation earlier in the pass, after the last
                // point the indirect buffer was used with writable usage, as well as batching
                // validation for multiple dispatches into one, but inserting commands at
                // arbitrary points in the past is not possible right now.
                DAWN_TRY_ASSIGN(std::tie(indirectBufferRef, indirectOffset),
                                TransformIndirectDispatchBuffer(indirectBufferRef, indirectOffset));

                // If we have created a new scratch dispatch indirect buffer in
                // TransformIndirectDispatchBuffer(), we need to track it in mUsageTracker.
                if (indirectBufferRef.Get() != indirectBuffer) {
                    // |indirectBufferRef| was replaced with a scratch buffer. Add it to the
                    // synchronization scope.
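                    // The scratch buffer was just written by the validation dispatch, so
                    // tracking its indirect usage here lets the backend insert any needed
                    // transition before the real dispatch consumes it.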
                    scope.BufferUsedAs(indirectBufferRef.Get(), wgpu::BufferUsage::Indirect);
                    mUsageTracker.AddReferencedBuffer(indirectBufferRef.Get());
                }

                AddDispatchSyncScope(std::move(scope));

                DispatchIndirectCmd* dispatch =
                    allocator->Allocate<DispatchIndirectCmd>(Command::DispatchIndirect);
                dispatch->indirectBuffer = std::move(indirectBufferRef);
                dispatch->indirectOffset = indirectOffset;
                return {};
            },
            "encoding %s.DispatchIndirect(%s, %u).", this, indirectBuffer, indirectOffset);
    }

    void ComputePassEncoder::APISetPipeline(ComputePipelineBase* pipeline) {
        mEncodingContext->TryEncode(
            this,
            [&](CommandAllocator* allocator) -> MaybeError {
                if (IsValidationEnabled()) {
                    DAWN_TRY(GetDevice()->ValidateObject(pipeline));
                }

                mCommandBufferState.SetComputePipeline(pipeline);

                SetComputePipelineCmd* cmd =
                    allocator->Allocate<SetComputePipelineCmd>(Command::SetComputePipeline);
                cmd->pipeline = pipeline;

                return {};
            },
            "encoding %s.SetPipeline(%s).", this, pipeline);
    }

    void ComputePassEncoder::APISetBindGroup(uint32_t groupIndexIn,
                                             BindGroupBase* group,
                                             uint32_t dynamicOffsetCount,
                                             const uint32_t* dynamicOffsets) {
        mEncodingContext->TryEncode(
            this,
            [&](CommandAllocator* allocator) -> MaybeError {
                BindGroupIndex groupIndex(groupIndexIn);

                if (IsValidationEnabled()) {
                    DAWN_TRY(ValidateSetBindGroup(groupIndex, group, dynamicOffsetCount,
                                                  dynamicOffsets));
                }

                mUsageTracker.AddResourcesReferencedByBindGroup(group);
                RecordSetBindGroup(allocator, groupIndex, group, dynamicOffsetCount,
                                   dynamicOffsets);
                mCommandBufferState.SetBindGroup(groupIndex, group, dynamicOffsetCount,
                                                 dynamicOffsets);

                return {};
            },
            "encoding %s.SetBindGroup(%u, %s, %u, ...).", this, groupIndexIn, group,
            dynamicOffsetCount);
    }

    void ComputePassEncoder::APIWriteTimestamp(QuerySetBase* querySet, uint32_t queryIndex) {
        mEncodingContext->TryEncode(
            this,
            [&](CommandAllocator* allocator) -> MaybeError {
                if (IsValidationEnabled()) {
                    DAWN_TRY(GetDevice()->ValidateObject(querySet));
                    DAWN_TRY(ValidateTimestampQuery(querySet, queryIndex));
                }

                mCommandEncoder->TrackQueryAvailability(querySet, queryIndex);

                WriteTimestampCmd* cmd =
                    allocator->Allocate<WriteTimestampCmd>(Command::WriteTimestamp);
                cmd->querySet = querySet;
                cmd->queryIndex = queryIndex;

                return {};
            },
            "encoding %s.WriteTimestamp(%s, %u).", this, querySet, queryIndex);
    }

    void ComputePassEncoder::AddDispatchSyncScope(SyncScopeUsageTracker scope) {
        PipelineLayoutBase* layout = mCommandBufferState.GetPipelineLayout();
        for (BindGroupIndex i : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
            scope.AddBindGroup(mCommandBufferState.GetBindGroup(i));
        }
        mUsageTracker.AddDispatch(scope.AcquireSyncScopeUsage());
    }

    void ComputePassEncoder::RestoreCommandBufferState(CommandBufferStateTracker state) {
        // Encode commands for the backend to restore the pipeline and bind groups.
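        // The validation work in TransformIndirectDispatchBuffer() bound its own pipeline
        // and bind group, so re-issue the client's SetPipeline and SetBindGroup commands
        // (including dynamic offsets) for the backend; the frontend tracker is restored
        // directly at the end.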
        if (state.HasPipeline()) {
            APISetPipeline(state.GetComputePipeline());
        }
        for (BindGroupIndex i(0); i < kMaxBindGroupsTyped; ++i) {
            BindGroupBase* bg = state.GetBindGroup(i);
            if (bg != nullptr) {
                const std::vector<uint32_t>& offsets = state.GetDynamicOffsets(i);
                if (offsets.empty()) {
                    APISetBindGroup(static_cast<uint32_t>(i), bg);
                } else {
                    APISetBindGroup(static_cast<uint32_t>(i), bg, offsets.size(), offsets.data());
                }
            }
        }

        // Restore the frontend state tracking information.
        mCommandBufferState = std::move(state);
    }

    CommandBufferStateTracker* ComputePassEncoder::GetCommandBufferStateTrackerForTesting() {
        return &mCommandBufferState;
    }

}  // namespace dawn_native