1 // Copyright 2017 The Dawn Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include "dawn_native/ShaderModule.h" 16 17 #include "common/HashUtils.h" 18 #include "dawn_native/BindGroupLayout.h" 19 #include "dawn_native/Device.h" 20 #include "dawn_native/Pipeline.h" 21 #include "dawn_native/PipelineLayout.h" 22 23 #include <spirv-cross/spirv_cross.hpp> 24 #include <spirv-tools/libspirv.hpp> 25 26 #include <sstream> 27 28 namespace dawn_native { 29 ValidateShaderModuleDescriptor(DeviceBase *,const ShaderModuleDescriptor * descriptor)30 MaybeError ValidateShaderModuleDescriptor(DeviceBase*, 31 const ShaderModuleDescriptor* descriptor) { 32 if (descriptor->nextInChain != nullptr) { 33 return DAWN_VALIDATION_ERROR("nextInChain must be nullptr"); 34 } 35 36 spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_1); 37 38 std::ostringstream errorStream; 39 errorStream << "SPIRV Validation failure:" << std::endl; 40 41 spirvTools.SetMessageConsumer([&errorStream](spv_message_level_t level, const char*, 42 const spv_position_t& position, 43 const char* message) { 44 switch (level) { 45 case SPV_MSG_FATAL: 46 case SPV_MSG_INTERNAL_ERROR: 47 case SPV_MSG_ERROR: 48 errorStream << "error: line " << position.index << ": " << message << std::endl; 49 break; 50 case SPV_MSG_WARNING: 51 errorStream << "warning: line " << position.index << ": " << message 52 << std::endl; 53 break; 54 case SPV_MSG_INFO: 55 errorStream << "info: line " << position.index << ": " << message << std::endl; 56 break; 57 default: 58 break; 59 } 60 }); 61 62 if (!spirvTools.Validate(descriptor->code, descriptor->codeSize)) { 63 return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); 64 } 65 66 return {}; 67 } 68 69 // ShaderModuleBase 70 ShaderModuleBase(DeviceBase * device,const ShaderModuleDescriptor * descriptor,bool blueprint)71 ShaderModuleBase::ShaderModuleBase(DeviceBase* device, 72 const ShaderModuleDescriptor* descriptor, 73 bool blueprint) 74 : ObjectBase(device), 75 mCode(descriptor->code, descriptor->code + descriptor->codeSize), 76 mIsBlueprint(blueprint) { 77 } 78 ShaderModuleBase(DeviceBase * device,ObjectBase::ErrorTag tag)79 ShaderModuleBase::ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag) 80 : ObjectBase(device, tag) { 81 } 82 ~ShaderModuleBase()83 ShaderModuleBase::~ShaderModuleBase() { 84 // Do not uncache the actual cached object if we are a blueprint 85 if (!mIsBlueprint && !IsError()) { 86 GetDevice()->UncacheShaderModule(this); 87 } 88 } 89 90 // static MakeError(DeviceBase * device)91 ShaderModuleBase* ShaderModuleBase::MakeError(DeviceBase* device) { 92 return new ShaderModuleBase(device, ObjectBase::kError); 93 } 94 ExtractSpirvInfo(const spirv_cross::Compiler & compiler)95 void ShaderModuleBase::ExtractSpirvInfo(const spirv_cross::Compiler& compiler) { 96 ASSERT(!IsError()); 97 98 DeviceBase* device = GetDevice(); 99 // TODO(cwallez@chromium.org): make errors here creation errors 100 // currently errors here do not prevent the shadermodule from being used 101 const auto& resources = compiler.get_shader_resources(); 102 103 switch (compiler.get_execution_model()) { 104 case spv::ExecutionModelVertex: 105 mExecutionModel = ShaderStage::Vertex; 106 break; 107 case spv::ExecutionModelFragment: 108 mExecutionModel = ShaderStage::Fragment; 109 break; 110 case spv::ExecutionModelGLCompute: 111 mExecutionModel = ShaderStage::Compute; 112 break; 113 default: 114 UNREACHABLE(); 115 } 116 117 if (resources.push_constant_buffers.size() > 0) { 118 GetDevice()->HandleError("Push constants aren't supported."); 119 } 120 121 // Fill in bindingInfo with the SPIRV bindings 122 auto ExtractResourcesBinding = [this](const spirv_cross::SmallVector<spirv_cross::Resource>& 123 resources, 124 const spirv_cross::Compiler& compiler, 125 dawn::BindingType bindingType) { 126 for (const auto& resource : resources) { 127 ASSERT(compiler.get_decoration_bitset(resource.id).get(spv::DecorationBinding)); 128 ASSERT( 129 compiler.get_decoration_bitset(resource.id).get(spv::DecorationDescriptorSet)); 130 131 uint32_t binding = compiler.get_decoration(resource.id, spv::DecorationBinding); 132 uint32_t set = compiler.get_decoration(resource.id, spv::DecorationDescriptorSet); 133 134 if (binding >= kMaxBindingsPerGroup || set >= kMaxBindGroups) { 135 GetDevice()->HandleError("Binding over limits in the SPIRV"); 136 continue; 137 } 138 139 auto& info = mBindingInfo[set][binding]; 140 info.used = true; 141 info.id = resource.id; 142 info.base_type_id = resource.base_type_id; 143 info.type = bindingType; 144 } 145 }; 146 147 ExtractResourcesBinding(resources.uniform_buffers, compiler, 148 dawn::BindingType::UniformBuffer); 149 ExtractResourcesBinding(resources.separate_images, compiler, 150 dawn::BindingType::SampledTexture); 151 ExtractResourcesBinding(resources.separate_samplers, compiler, dawn::BindingType::Sampler); 152 ExtractResourcesBinding(resources.storage_buffers, compiler, 153 dawn::BindingType::StorageBuffer); 154 155 // Extract the vertex attributes 156 if (mExecutionModel == ShaderStage::Vertex) { 157 for (const auto& attrib : resources.stage_inputs) { 158 ASSERT(compiler.get_decoration_bitset(attrib.id).get(spv::DecorationLocation)); 159 uint32_t location = compiler.get_decoration(attrib.id, spv::DecorationLocation); 160 161 if (location >= kMaxVertexAttributes) { 162 device->HandleError("Attribute location over limits in the SPIRV"); 163 return; 164 } 165 166 mUsedVertexAttributes.set(location); 167 } 168 169 // Without a location qualifier on vertex outputs, spirv_cross::CompilerMSL gives them 170 // all the location 0, causing a compile error. 171 for (const auto& attrib : resources.stage_outputs) { 172 if (!compiler.get_decoration_bitset(attrib.id).get(spv::DecorationLocation)) { 173 device->HandleError("Need location qualifier on vertex output"); 174 return; 175 } 176 } 177 } 178 179 if (mExecutionModel == ShaderStage::Fragment) { 180 // Without a location qualifier on vertex inputs, spirv_cross::CompilerMSL gives them 181 // all the location 0, causing a compile error. 182 for (const auto& attrib : resources.stage_inputs) { 183 if (!compiler.get_decoration_bitset(attrib.id).get(spv::DecorationLocation)) { 184 device->HandleError("Need location qualifier on fragment input"); 185 return; 186 } 187 } 188 } 189 } 190 GetBindingInfo() const191 const ShaderModuleBase::ModuleBindingInfo& ShaderModuleBase::GetBindingInfo() const { 192 ASSERT(!IsError()); 193 return mBindingInfo; 194 } 195 GetUsedVertexAttributes() const196 const std::bitset<kMaxVertexAttributes>& ShaderModuleBase::GetUsedVertexAttributes() const { 197 ASSERT(!IsError()); 198 return mUsedVertexAttributes; 199 } 200 GetExecutionModel() const201 ShaderStage ShaderModuleBase::GetExecutionModel() const { 202 ASSERT(!IsError()); 203 return mExecutionModel; 204 } 205 IsCompatibleWithPipelineLayout(const PipelineLayoutBase * layout)206 bool ShaderModuleBase::IsCompatibleWithPipelineLayout(const PipelineLayoutBase* layout) { 207 ASSERT(!IsError()); 208 209 for (uint32_t group : IterateBitSet(layout->GetBindGroupLayoutsMask())) { 210 if (!IsCompatibleWithBindGroupLayout(group, layout->GetBindGroupLayout(group))) { 211 return false; 212 } 213 } 214 215 for (uint32_t group : IterateBitSet(~layout->GetBindGroupLayoutsMask())) { 216 for (size_t i = 0; i < kMaxBindingsPerGroup; ++i) { 217 if (mBindingInfo[group][i].used) { 218 return false; 219 } 220 } 221 } 222 223 return true; 224 } 225 IsCompatibleWithBindGroupLayout(size_t group,const BindGroupLayoutBase * layout)226 bool ShaderModuleBase::IsCompatibleWithBindGroupLayout(size_t group, 227 const BindGroupLayoutBase* layout) { 228 ASSERT(!IsError()); 229 230 const auto& layoutInfo = layout->GetBindingInfo(); 231 for (size_t i = 0; i < kMaxBindingsPerGroup; ++i) { 232 const auto& moduleInfo = mBindingInfo[group][i]; 233 const auto& layoutBindingType = layoutInfo.types[i]; 234 235 if (!moduleInfo.used) { 236 continue; 237 } 238 239 if (layoutBindingType != moduleInfo.type) { 240 return false; 241 } 242 243 if ((layoutInfo.visibilities[i] & StageBit(mExecutionModel)) == 0) { 244 return false; 245 } 246 } 247 248 return true; 249 } 250 operator ()(const ShaderModuleBase * module) const251 size_t ShaderModuleBase::HashFunc::operator()(const ShaderModuleBase* module) const { 252 size_t hash = 0; 253 254 for (uint32_t word : module->mCode) { 255 HashCombine(&hash, word); 256 } 257 258 return hash; 259 } 260 operator ()(const ShaderModuleBase * a,const ShaderModuleBase * b) const261 bool ShaderModuleBase::EqualityFunc::operator()(const ShaderModuleBase* a, 262 const ShaderModuleBase* b) const { 263 return a->mCode == b->mCode; 264 } 265 266 } // namespace dawn_native 267