• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2017 The Dawn Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "dawn_native/ShaderModule.h"
16 
17 #include "common/HashUtils.h"
18 #include "dawn_native/BindGroupLayout.h"
19 #include "dawn_native/Device.h"
20 #include "dawn_native/Pipeline.h"
21 #include "dawn_native/PipelineLayout.h"
22 
23 #include <spirv-cross/spirv_cross.hpp>
24 #include <spirv-tools/libspirv.hpp>
25 
26 #include <sstream>
27 
28 namespace dawn_native {
29 
ValidateShaderModuleDescriptor(DeviceBase *,const ShaderModuleDescriptor * descriptor)30     MaybeError ValidateShaderModuleDescriptor(DeviceBase*,
31                                               const ShaderModuleDescriptor* descriptor) {
32         if (descriptor->nextInChain != nullptr) {
33             return DAWN_VALIDATION_ERROR("nextInChain must be nullptr");
34         }
35 
36         spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_1);
37 
38         std::ostringstream errorStream;
39         errorStream << "SPIRV Validation failure:" << std::endl;
40 
41         spirvTools.SetMessageConsumer([&errorStream](spv_message_level_t level, const char*,
42                                                      const spv_position_t& position,
43                                                      const char* message) {
44             switch (level) {
45                 case SPV_MSG_FATAL:
46                 case SPV_MSG_INTERNAL_ERROR:
47                 case SPV_MSG_ERROR:
48                     errorStream << "error: line " << position.index << ": " << message << std::endl;
49                     break;
50                 case SPV_MSG_WARNING:
51                     errorStream << "warning: line " << position.index << ": " << message
52                                 << std::endl;
53                     break;
54                 case SPV_MSG_INFO:
55                     errorStream << "info: line " << position.index << ": " << message << std::endl;
56                     break;
57                 default:
58                     break;
59             }
60         });
61 
62         if (!spirvTools.Validate(descriptor->code, descriptor->codeSize)) {
63             return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
64         }
65 
66         return {};
67     }
68 
69     // ShaderModuleBase
70 
ShaderModuleBase(DeviceBase * device,const ShaderModuleDescriptor * descriptor,bool blueprint)71     ShaderModuleBase::ShaderModuleBase(DeviceBase* device,
72                                        const ShaderModuleDescriptor* descriptor,
73                                        bool blueprint)
74         : ObjectBase(device),
75           mCode(descriptor->code, descriptor->code + descriptor->codeSize),
76           mIsBlueprint(blueprint) {
77     }
78 
ShaderModuleBase(DeviceBase * device,ObjectBase::ErrorTag tag)79     ShaderModuleBase::ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag)
80         : ObjectBase(device, tag) {
81     }
82 
~ShaderModuleBase()83     ShaderModuleBase::~ShaderModuleBase() {
84         // Do not uncache the actual cached object if we are a blueprint
85         if (!mIsBlueprint && !IsError()) {
86             GetDevice()->UncacheShaderModule(this);
87         }
88     }
89 
90     // static
MakeError(DeviceBase * device)91     ShaderModuleBase* ShaderModuleBase::MakeError(DeviceBase* device) {
92         return new ShaderModuleBase(device, ObjectBase::kError);
93     }
94 
ExtractSpirvInfo(const spirv_cross::Compiler & compiler)95     void ShaderModuleBase::ExtractSpirvInfo(const spirv_cross::Compiler& compiler) {
96         ASSERT(!IsError());
97 
98         DeviceBase* device = GetDevice();
99         // TODO(cwallez@chromium.org): make errors here creation errors
100         // currently errors here do not prevent the shadermodule from being used
101         const auto& resources = compiler.get_shader_resources();
102 
103         switch (compiler.get_execution_model()) {
104             case spv::ExecutionModelVertex:
105                 mExecutionModel = ShaderStage::Vertex;
106                 break;
107             case spv::ExecutionModelFragment:
108                 mExecutionModel = ShaderStage::Fragment;
109                 break;
110             case spv::ExecutionModelGLCompute:
111                 mExecutionModel = ShaderStage::Compute;
112                 break;
113             default:
114                 UNREACHABLE();
115         }
116 
117         if (resources.push_constant_buffers.size() > 0) {
118             GetDevice()->HandleError("Push constants aren't supported.");
119         }
120 
121         // Fill in bindingInfo with the SPIRV bindings
122         auto ExtractResourcesBinding = [this](const spirv_cross::SmallVector<spirv_cross::Resource>&
123                                                   resources,
124                                               const spirv_cross::Compiler& compiler,
125                                               dawn::BindingType bindingType) {
126             for (const auto& resource : resources) {
127                 ASSERT(compiler.get_decoration_bitset(resource.id).get(spv::DecorationBinding));
128                 ASSERT(
129                     compiler.get_decoration_bitset(resource.id).get(spv::DecorationDescriptorSet));
130 
131                 uint32_t binding = compiler.get_decoration(resource.id, spv::DecorationBinding);
132                 uint32_t set = compiler.get_decoration(resource.id, spv::DecorationDescriptorSet);
133 
134                 if (binding >= kMaxBindingsPerGroup || set >= kMaxBindGroups) {
135                     GetDevice()->HandleError("Binding over limits in the SPIRV");
136                     continue;
137                 }
138 
139                 auto& info = mBindingInfo[set][binding];
140                 info.used = true;
141                 info.id = resource.id;
142                 info.base_type_id = resource.base_type_id;
143                 info.type = bindingType;
144             }
145         };
146 
147         ExtractResourcesBinding(resources.uniform_buffers, compiler,
148                                 dawn::BindingType::UniformBuffer);
149         ExtractResourcesBinding(resources.separate_images, compiler,
150                                 dawn::BindingType::SampledTexture);
151         ExtractResourcesBinding(resources.separate_samplers, compiler, dawn::BindingType::Sampler);
152         ExtractResourcesBinding(resources.storage_buffers, compiler,
153                                 dawn::BindingType::StorageBuffer);
154 
155         // Extract the vertex attributes
156         if (mExecutionModel == ShaderStage::Vertex) {
157             for (const auto& attrib : resources.stage_inputs) {
158                 ASSERT(compiler.get_decoration_bitset(attrib.id).get(spv::DecorationLocation));
159                 uint32_t location = compiler.get_decoration(attrib.id, spv::DecorationLocation);
160 
161                 if (location >= kMaxVertexAttributes) {
162                     device->HandleError("Attribute location over limits in the SPIRV");
163                     return;
164                 }
165 
166                 mUsedVertexAttributes.set(location);
167             }
168 
169             // Without a location qualifier on vertex outputs, spirv_cross::CompilerMSL gives them
170             // all the location 0, causing a compile error.
171             for (const auto& attrib : resources.stage_outputs) {
172                 if (!compiler.get_decoration_bitset(attrib.id).get(spv::DecorationLocation)) {
173                     device->HandleError("Need location qualifier on vertex output");
174                     return;
175                 }
176             }
177         }
178 
179         if (mExecutionModel == ShaderStage::Fragment) {
180             // Without a location qualifier on vertex inputs, spirv_cross::CompilerMSL gives them
181             // all the location 0, causing a compile error.
182             for (const auto& attrib : resources.stage_inputs) {
183                 if (!compiler.get_decoration_bitset(attrib.id).get(spv::DecorationLocation)) {
184                     device->HandleError("Need location qualifier on fragment input");
185                     return;
186                 }
187             }
188         }
189     }
190 
GetBindingInfo() const191     const ShaderModuleBase::ModuleBindingInfo& ShaderModuleBase::GetBindingInfo() const {
192         ASSERT(!IsError());
193         return mBindingInfo;
194     }
195 
GetUsedVertexAttributes() const196     const std::bitset<kMaxVertexAttributes>& ShaderModuleBase::GetUsedVertexAttributes() const {
197         ASSERT(!IsError());
198         return mUsedVertexAttributes;
199     }
200 
GetExecutionModel() const201     ShaderStage ShaderModuleBase::GetExecutionModel() const {
202         ASSERT(!IsError());
203         return mExecutionModel;
204     }
205 
IsCompatibleWithPipelineLayout(const PipelineLayoutBase * layout)206     bool ShaderModuleBase::IsCompatibleWithPipelineLayout(const PipelineLayoutBase* layout) {
207         ASSERT(!IsError());
208 
209         for (uint32_t group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
210             if (!IsCompatibleWithBindGroupLayout(group, layout->GetBindGroupLayout(group))) {
211                 return false;
212             }
213         }
214 
215         for (uint32_t group : IterateBitSet(~layout->GetBindGroupLayoutsMask())) {
216             for (size_t i = 0; i < kMaxBindingsPerGroup; ++i) {
217                 if (mBindingInfo[group][i].used) {
218                     return false;
219                 }
220             }
221         }
222 
223         return true;
224     }
225 
IsCompatibleWithBindGroupLayout(size_t group,const BindGroupLayoutBase * layout)226     bool ShaderModuleBase::IsCompatibleWithBindGroupLayout(size_t group,
227                                                            const BindGroupLayoutBase* layout) {
228         ASSERT(!IsError());
229 
230         const auto& layoutInfo = layout->GetBindingInfo();
231         for (size_t i = 0; i < kMaxBindingsPerGroup; ++i) {
232             const auto& moduleInfo = mBindingInfo[group][i];
233             const auto& layoutBindingType = layoutInfo.types[i];
234 
235             if (!moduleInfo.used) {
236                 continue;
237             }
238 
239             if (layoutBindingType != moduleInfo.type) {
240                 return false;
241             }
242 
243             if ((layoutInfo.visibilities[i] & StageBit(mExecutionModel)) == 0) {
244                 return false;
245             }
246         }
247 
248         return true;
249     }
250 
operator ()(const ShaderModuleBase * module) const251     size_t ShaderModuleBase::HashFunc::operator()(const ShaderModuleBase* module) const {
252         size_t hash = 0;
253 
254         for (uint32_t word : module->mCode) {
255             HashCombine(&hash, word);
256         }
257 
258         return hash;
259     }
260 
operator ()(const ShaderModuleBase * a,const ShaderModuleBase * b) const261     bool ShaderModuleBase::EqualityFunc::operator()(const ShaderModuleBase* a,
262                                                     const ShaderModuleBase* b) const {
263         return a->mCode == b->mCode;
264     }
265 
266 }  // namespace dawn_native
267