• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2017 The Dawn Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
#include "dawn_native/Pipeline.h"

#include "dawn_native/BindGroupLayout.h"
#include "dawn_native/Device.h"
#include "dawn_native/ObjectBase.h"
#include "dawn_native/ObjectContentHasher.h"
#include "dawn_native/PipelineLayout.h"
#include "dawn_native/ShaderModule.h"

#include <functional>
23 
24 namespace dawn_native {
ValidateProgrammableStage(DeviceBase * device,const ShaderModuleBase * module,const std::string & entryPoint,uint32_t constantCount,const ConstantEntry * constants,const PipelineLayoutBase * layout,SingleShaderStage stage)25     MaybeError ValidateProgrammableStage(DeviceBase* device,
26                                          const ShaderModuleBase* module,
27                                          const std::string& entryPoint,
28                                          uint32_t constantCount,
29                                          const ConstantEntry* constants,
30                                          const PipelineLayoutBase* layout,
31                                          SingleShaderStage stage) {
32         DAWN_TRY(device->ValidateObject(module));
33 
34         DAWN_INVALID_IF(!module->HasEntryPoint(entryPoint),
35                         "Entry point \"%s\" doesn't exist in the shader module %s.", entryPoint,
36                         module);
37 
38         const EntryPointMetadata& metadata = module->GetEntryPoint(entryPoint);
39 
40         DAWN_INVALID_IF(metadata.stage != stage,
41                         "The stage (%s) of the entry point \"%s\" isn't the expected one (%s).",
42                         metadata.stage, entryPoint, stage);
43 
44         if (layout != nullptr) {
45             DAWN_TRY(ValidateCompatibilityWithPipelineLayout(device, metadata, layout));
46         }
47 
48         if (constantCount > 0u && device->IsToggleEnabled(Toggle::DisallowUnsafeAPIs)) {
49             return DAWN_VALIDATION_ERROR(
50                 "Pipeline overridable constants are disallowed because they are partially "
51                 "implemented.");
52         }
53 
54         // Validate if overridable constants exist in shader module
55         // pipelineBase is not yet constructed at this moment so iterate constants from descriptor
56         size_t numUninitializedConstants = metadata.uninitializedOverridableConstants.size();
57         // Keep an initialized constants sets to handle duplicate initialization cases
58         std::unordered_set<std::string> stageInitializedConstantIdentifiers;
59         for (uint32_t i = 0; i < constantCount; i++) {
60             DAWN_INVALID_IF(metadata.overridableConstants.count(constants[i].key) == 0,
61                             "Pipeline overridable constant \"%s\" not found in %s.",
62                             constants[i].key, module);
63 
64             if (stageInitializedConstantIdentifiers.count(constants[i].key) == 0) {
65                 if (metadata.uninitializedOverridableConstants.count(constants[i].key) > 0) {
66                     numUninitializedConstants--;
67                 }
68                 stageInitializedConstantIdentifiers.insert(constants[i].key);
69             } else {
70                 // There are duplicate initializations
71                 return DAWN_FORMAT_VALIDATION_ERROR(
72                     "Pipeline overridable constants \"%s\" is set more than once in %s",
73                     constants[i].key, module);
74             }
75         }
76 
77         // Validate if any overridable constant is left uninitialized
78         if (DAWN_UNLIKELY(numUninitializedConstants > 0)) {
79             std::string uninitializedConstantsArray;
80             bool isFirst = true;
81             for (std::string identifier : metadata.uninitializedOverridableConstants) {
82                 if (stageInitializedConstantIdentifiers.count(identifier) > 0) {
83                     continue;
84                 }
85 
86                 if (isFirst) {
87                     isFirst = false;
88                 } else {
89                     uninitializedConstantsArray.append(", ");
90                 }
91                 uninitializedConstantsArray.append(identifier);
92             }
93 
94             return DAWN_FORMAT_VALIDATION_ERROR(
95                 "There are uninitialized pipeline overridable constants in shader module %s, their "
96                 "identifiers:[%s]",
97                 module, uninitializedConstantsArray);
98         }
99 
100         return {};
101     }
102 
103     // PipelineBase
104 
PipelineBase(DeviceBase * device,PipelineLayoutBase * layout,const char * label,std::vector<StageAndDescriptor> stages)105     PipelineBase::PipelineBase(DeviceBase* device,
106                                PipelineLayoutBase* layout,
107                                const char* label,
108                                std::vector<StageAndDescriptor> stages)
109         : ApiObjectBase(device, label), mLayout(layout) {
110         ASSERT(!stages.empty());
111 
112         for (const StageAndDescriptor& stage : stages) {
113             // Extract argument for this stage.
114             SingleShaderStage shaderStage = stage.shaderStage;
115             ShaderModuleBase* module = stage.module;
116             const char* entryPointName = stage.entryPoint.c_str();
117 
118             const EntryPointMetadata& metadata = module->GetEntryPoint(entryPointName);
119             ASSERT(metadata.stage == shaderStage);
120 
121             // Record them internally.
122             bool isFirstStage = mStageMask == wgpu::ShaderStage::None;
123             mStageMask |= StageBit(shaderStage);
124             mStages[shaderStage] = {module, entryPointName, &metadata, {}};
125             auto& constants = mStages[shaderStage].constants;
126             for (uint32_t i = 0; i < stage.constantCount; i++) {
127                 constants.emplace(stage.constants[i].key, stage.constants[i].value);
128             }
129 
130             // Compute the max() of all minBufferSizes across all stages.
131             RequiredBufferSizes stageMinBufferSizes =
132                 ComputeRequiredBufferSizesForLayout(metadata, layout);
133 
134             if (isFirstStage) {
135                 mMinBufferSizes = std::move(stageMinBufferSizes);
136             } else {
137                 for (BindGroupIndex group(0); group < mMinBufferSizes.size(); ++group) {
138                     ASSERT(stageMinBufferSizes[group].size() == mMinBufferSizes[group].size());
139 
140                     for (size_t i = 0; i < stageMinBufferSizes[group].size(); ++i) {
141                         mMinBufferSizes[group][i] =
142                             std::max(mMinBufferSizes[group][i], stageMinBufferSizes[group][i]);
143                     }
144                 }
145             }
146         }
147     }
148 
PipelineBase(DeviceBase * device)149     PipelineBase::PipelineBase(DeviceBase* device) : ApiObjectBase(device, kLabelNotImplemented) {
150     }
151 
PipelineBase(DeviceBase * device,ObjectBase::ErrorTag tag)152     PipelineBase::PipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag)
153         : ApiObjectBase(device, tag) {
154     }
155 
156     PipelineBase::~PipelineBase() = default;
157 
GetLayout()158     PipelineLayoutBase* PipelineBase::GetLayout() {
159         ASSERT(!IsError());
160         return mLayout.Get();
161     }
162 
GetLayout() const163     const PipelineLayoutBase* PipelineBase::GetLayout() const {
164         ASSERT(!IsError());
165         return mLayout.Get();
166     }
167 
GetMinBufferSizes() const168     const RequiredBufferSizes& PipelineBase::GetMinBufferSizes() const {
169         ASSERT(!IsError());
170         return mMinBufferSizes;
171     }
172 
GetStage(SingleShaderStage stage) const173     const ProgrammableStage& PipelineBase::GetStage(SingleShaderStage stage) const {
174         ASSERT(!IsError());
175         return mStages[stage];
176     }
177 
GetAllStages() const178     const PerStage<ProgrammableStage>& PipelineBase::GetAllStages() const {
179         return mStages;
180     }
181 
GetStageMask() const182     wgpu::ShaderStage PipelineBase::GetStageMask() const {
183         return mStageMask;
184     }
185 
ValidateGetBindGroupLayout(uint32_t groupIndex)186     MaybeError PipelineBase::ValidateGetBindGroupLayout(uint32_t groupIndex) {
187         DAWN_TRY(GetDevice()->ValidateIsAlive());
188         DAWN_TRY(GetDevice()->ValidateObject(this));
189         DAWN_TRY(GetDevice()->ValidateObject(mLayout.Get()));
190         DAWN_INVALID_IF(
191             groupIndex >= kMaxBindGroups,
192             "Bind group layout index (%u) exceeds the maximum number of bind groups (%u).",
193             groupIndex, kMaxBindGroups);
194         return {};
195     }
196 
GetBindGroupLayout(uint32_t groupIndexIn)197     ResultOrError<Ref<BindGroupLayoutBase>> PipelineBase::GetBindGroupLayout(
198         uint32_t groupIndexIn) {
199         DAWN_TRY(ValidateGetBindGroupLayout(groupIndexIn));
200 
201         BindGroupIndex groupIndex(groupIndexIn);
202         if (!mLayout->GetBindGroupLayoutsMask()[groupIndex]) {
203             return Ref<BindGroupLayoutBase>(GetDevice()->GetEmptyBindGroupLayout());
204         } else {
205             return Ref<BindGroupLayoutBase>(mLayout->GetBindGroupLayout(groupIndex));
206         }
207     }
208 
APIGetBindGroupLayout(uint32_t groupIndexIn)209     BindGroupLayoutBase* PipelineBase::APIGetBindGroupLayout(uint32_t groupIndexIn) {
210         Ref<BindGroupLayoutBase> result;
211         if (GetDevice()->ConsumedError(GetBindGroupLayout(groupIndexIn), &result,
212                                        "Validating GetBindGroupLayout (%u) on %s", groupIndexIn,
213                                        this)) {
214             return BindGroupLayoutBase::MakeError(GetDevice());
215         }
216         return result.Detach();
217     }
218 
ComputeContentHash()219     size_t PipelineBase::ComputeContentHash() {
220         ObjectContentHasher recorder;
221         recorder.Record(mLayout->GetContentHash());
222 
223         recorder.Record(mStageMask);
224         for (SingleShaderStage stage : IterateStages(mStageMask)) {
225             recorder.Record(mStages[stage].module->GetContentHash());
226             recorder.Record(mStages[stage].entryPoint);
227         }
228 
229         return recorder.GetContentHash();
230     }
231 
232     // static
EqualForCache(const PipelineBase * a,const PipelineBase * b)233     bool PipelineBase::EqualForCache(const PipelineBase* a, const PipelineBase* b) {
234         // The layout is deduplicated so it can be compared by pointer.
235         if (a->mLayout.Get() != b->mLayout.Get() || a->mStageMask != b->mStageMask) {
236             return false;
237         }
238 
239         for (SingleShaderStage stage : IterateStages(a->mStageMask)) {
240             // The module is deduplicated so it can be compared by pointer.
241             if (a->mStages[stage].module.Get() != b->mStages[stage].module.Get() ||
242                 a->mStages[stage].entryPoint != b->mStages[stage].entryPoint) {
243                 return false;
244             }
245         }
246 
247         return true;
248     }
249 
250 }  // namespace dawn_native
251