1 // Copyright 2017 The Dawn Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include "dawn_native/d3d12/ComputePipelineD3D12.h" 16 17 #include "dawn_native/CreatePipelineAsyncTask.h" 18 #include "dawn_native/d3d12/D3D12Error.h" 19 #include "dawn_native/d3d12/DeviceD3D12.h" 20 #include "dawn_native/d3d12/PipelineLayoutD3D12.h" 21 #include "dawn_native/d3d12/PlatformFunctions.h" 22 #include "dawn_native/d3d12/ShaderModuleD3D12.h" 23 #include "dawn_native/d3d12/UtilsD3D12.h" 24 25 namespace dawn_native { namespace d3d12 { 26 CreateUninitialized(Device * device,const ComputePipelineDescriptor * descriptor)27 Ref<ComputePipeline> ComputePipeline::CreateUninitialized( 28 Device* device, 29 const ComputePipelineDescriptor* descriptor) { 30 return AcquireRef(new ComputePipeline(device, descriptor)); 31 } 32 Initialize()33 MaybeError ComputePipeline::Initialize() { 34 Device* device = ToBackend(GetDevice()); 35 uint32_t compileFlags = 0; 36 37 if (!device->IsToggleEnabled(Toggle::UseDXC) && 38 !device->IsToggleEnabled(Toggle::FxcOptimizations)) { 39 compileFlags |= D3DCOMPILE_OPTIMIZATION_LEVEL0; 40 } 41 42 if (device->IsToggleEnabled(Toggle::EmitHLSLDebugSymbols)) { 43 compileFlags |= D3DCOMPILE_DEBUG | D3DCOMPILE_SKIP_OPTIMIZATION; 44 } 45 46 // SPRIV-cross does matrix multiplication expecting row major matrices 47 compileFlags |= D3DCOMPILE_PACK_MATRIX_ROW_MAJOR; 48 49 const ProgrammableStage& computeStage = GetStage(SingleShaderStage::Compute); 50 ShaderModule* module = ToBackend(computeStage.module.Get()); 51 52 D3D12_COMPUTE_PIPELINE_STATE_DESC d3dDesc = {}; 53 d3dDesc.pRootSignature = ToBackend(GetLayout())->GetRootSignature(); 54 55 CompiledShader compiledShader; 56 DAWN_TRY_ASSIGN(compiledShader, module->Compile(computeStage, SingleShaderStage::Compute, 57 ToBackend(GetLayout()), compileFlags)); 58 d3dDesc.CS = compiledShader.GetD3D12ShaderBytecode(); 59 auto* d3d12Device = device->GetD3D12Device(); 60 DAWN_TRY(CheckHRESULT( 61 d3d12Device->CreateComputePipelineState(&d3dDesc, IID_PPV_ARGS(&mPipelineState)), 62 "D3D12 creating pipeline state")); 63 64 SetLabelImpl(); 65 66 return {}; 67 } 68 69 ComputePipeline::~ComputePipeline() = default; 70 DestroyImpl()71 void ComputePipeline::DestroyImpl() { 72 ComputePipelineBase::DestroyImpl(); 73 ToBackend(GetDevice())->ReferenceUntilUnused(mPipelineState); 74 } 75 GetPipelineState() const76 ID3D12PipelineState* ComputePipeline::GetPipelineState() const { 77 return mPipelineState.Get(); 78 } 79 SetLabelImpl()80 void ComputePipeline::SetLabelImpl() { 81 SetDebugName(ToBackend(GetDevice()), GetPipelineState(), "Dawn_ComputePipeline", 82 GetLabel()); 83 } 84 InitializeAsync(Ref<ComputePipelineBase> computePipeline,WGPUCreateComputePipelineAsyncCallback callback,void * userdata)85 void ComputePipeline::InitializeAsync(Ref<ComputePipelineBase> computePipeline, 86 WGPUCreateComputePipelineAsyncCallback callback, 87 void* userdata) { 88 std::unique_ptr<CreateComputePipelineAsyncTask> asyncTask = 89 std::make_unique<CreateComputePipelineAsyncTask>(std::move(computePipeline), callback, 90 userdata); 91 CreateComputePipelineAsyncTask::RunAsync(std::move(asyncTask)); 92 } 93 UsesNumWorkgroups() const94 bool ComputePipeline::UsesNumWorkgroups() const { 95 return GetStage(SingleShaderStage::Compute).metadata->usesNumWorkgroups; 96 } 97 GetDispatchIndirectCommandSignature()98 ComPtr<ID3D12CommandSignature> ComputePipeline::GetDispatchIndirectCommandSignature() { 99 if (UsesNumWorkgroups()) { 100 return ToBackend(GetLayout())->GetDispatchIndirectCommandSignatureWithNumWorkgroups(); 101 } 102 return ToBackend(GetDevice())->GetDispatchIndirectSignature(); 103 } 104 105 }} // namespace dawn_native::d3d12 106