1 // Copyright 2017 The Dawn Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef DAWNNATIVE_SHADERMODULE_H_ 16 #define DAWNNATIVE_SHADERMODULE_H_ 17 18 #include "common/Constants.h" 19 #include "common/ityp_array.h" 20 #include "dawn_native/BindingInfo.h" 21 #include "dawn_native/CachedObject.h" 22 #include "dawn_native/CompilationMessages.h" 23 #include "dawn_native/Error.h" 24 #include "dawn_native/Format.h" 25 #include "dawn_native/Forward.h" 26 #include "dawn_native/IntegerTypes.h" 27 #include "dawn_native/ObjectBase.h" 28 #include "dawn_native/PerStage.h" 29 #include "dawn_native/VertexFormat.h" 30 #include "dawn_native/dawn_platform.h" 31 32 #include <bitset> 33 #include <map> 34 #include <unordered_map> 35 #include <unordered_set> 36 #include <vector> 37 38 namespace tint { 39 40 class Program; 41 42 namespace transform { 43 class DataMap; 44 class Transform; 45 class VertexPulling; 46 } // namespace transform 47 48 } // namespace tint 49 50 namespace dawn_native { 51 52 struct EntryPointMetadata; 53 54 // Base component type of an inter-stage variable 55 enum class InterStageComponentType { 56 Sint, 57 Uint, 58 Float, 59 }; 60 61 enum class InterpolationType { 62 Perspective, 63 Linear, 64 Flat, 65 }; 66 67 enum class InterpolationSampling { 68 None, 69 Center, 70 Centroid, 71 Sample, 72 }; 73 74 using PipelineLayoutEntryPointPair = std::pair<PipelineLayoutBase*, std::string>; 75 struct PipelineLayoutEntryPointPairHashFunc { 76 size_t operator()(const PipelineLayoutEntryPointPair& pair) const; 77 }; 78 79 // A map from name to EntryPointMetadata. 80 using EntryPointMetadataTable = 81 std::unordered_map<std::string, std::unique_ptr<EntryPointMetadata>>; 82 83 // Source for a tint program 84 class TintSource; 85 86 struct ShaderModuleParseResult { 87 ShaderModuleParseResult(); 88 ~ShaderModuleParseResult(); 89 ShaderModuleParseResult(ShaderModuleParseResult&& rhs); 90 ShaderModuleParseResult& operator=(ShaderModuleParseResult&& rhs); 91 92 bool HasParsedShader() const; 93 94 std::unique_ptr<tint::Program> tintProgram; 95 std::unique_ptr<TintSource> tintSource; 96 }; 97 98 MaybeError ValidateShaderModuleDescriptor(DeviceBase* device, 99 const ShaderModuleDescriptor* descriptor, 100 ShaderModuleParseResult* parseResult, 101 OwnedCompilationMessages* outMessages); 102 MaybeError ValidateCompatibilityWithPipelineLayout(DeviceBase* device, 103 const EntryPointMetadata& entryPoint, 104 const PipelineLayoutBase* layout); 105 106 RequiredBufferSizes ComputeRequiredBufferSizesForLayout(const EntryPointMetadata& entryPoint, 107 const PipelineLayoutBase* layout); 108 ResultOrError<tint::Program> RunTransforms(tint::transform::Transform* transform, 109 const tint::Program* program, 110 const tint::transform::DataMap& inputs, 111 tint::transform::DataMap* outputs, 112 OwnedCompilationMessages* messages); 113 114 /// Creates and adds the tint::transform::VertexPulling::Config to transformInputs. 115 void AddVertexPullingTransformConfig(const RenderPipelineBase& renderPipeline, 116 const std::string& entryPoint, 117 BindGroupIndex pullingBufferBindingSet, 118 tint::transform::DataMap* transformInputs); 119 120 // Mirrors wgpu::SamplerBindingLayout but instead stores a single boolean 121 // for isComparison instead of a wgpu::SamplerBindingType enum. 122 struct ShaderSamplerBindingInfo { 123 bool isComparison; 124 }; 125 126 // Mirrors wgpu::TextureBindingLayout but instead has a set of compatible sampleTypes 127 // instead of a single enum. 128 struct ShaderTextureBindingInfo { 129 SampleTypeBit compatibleSampleTypes; 130 wgpu::TextureViewDimension viewDimension; 131 bool multisampled; 132 }; 133 134 // Per-binding shader metadata contains some SPIRV specific information in addition to 135 // most of the frontend per-binding information. 136 struct ShaderBindingInfo { 137 // The SPIRV ID of the resource. 138 uint32_t id; 139 uint32_t base_type_id; 140 141 BindingNumber binding; 142 BindingInfoType bindingType; 143 144 BufferBindingLayout buffer; 145 ShaderSamplerBindingInfo sampler; 146 ShaderTextureBindingInfo texture; 147 StorageTextureBindingLayout storageTexture; 148 }; 149 150 using BindingGroupInfoMap = std::map<BindingNumber, ShaderBindingInfo>; 151 using BindingInfoArray = ityp::array<BindGroupIndex, BindingGroupInfoMap, kMaxBindGroups>; 152 153 // The WebGPU overridable constants only support these scalar types 154 union OverridableConstantScalar { 155 // Use int32_t for boolean to initialize the full 32bit 156 int32_t b; 157 float f32; 158 int32_t i32; 159 uint32_t u32; 160 }; 161 162 // Contains all the reflection data for a valid (ShaderModule, entryPoint, stage). They are 163 // stored in the ShaderModuleBase and destroyed only when the shader program is destroyed so 164 // pointers to EntryPointMetadata are safe to store as long as you also keep a Ref to the 165 // ShaderModuleBase. 166 struct EntryPointMetadata { 167 // bindings[G][B] is the reflection data for the binding defined with 168 // [[group=G, binding=B]] in WGSL / SPIRV. 169 BindingInfoArray bindings; 170 171 struct SamplerTexturePair { 172 BindingSlot sampler; 173 BindingSlot texture; 174 }; 175 std::vector<SamplerTexturePair> samplerTexturePairs; 176 177 // The set of vertex attributes this entryPoint uses. 178 ityp::array<VertexAttributeLocation, VertexFormatBaseType, kMaxVertexAttributes> 179 vertexInputBaseTypes; 180 ityp::bitset<VertexAttributeLocation, kMaxVertexAttributes> usedVertexInputs; 181 182 // An array to record the basic types (float, int and uint) of the fragment shader outputs. 183 struct FragmentOutputVariableInfo { 184 wgpu::TextureComponentType baseType; 185 uint8_t componentCount; 186 }; 187 ityp::array<ColorAttachmentIndex, FragmentOutputVariableInfo, kMaxColorAttachments> 188 fragmentOutputVariables; 189 ityp::bitset<ColorAttachmentIndex, kMaxColorAttachments> fragmentOutputsWritten; 190 191 struct InterStageVariableInfo { 192 InterStageComponentType baseType; 193 uint32_t componentCount; 194 InterpolationType interpolationType; 195 InterpolationSampling interpolationSampling; 196 }; 197 // Now that we only support vertex and fragment stages, there can't be both inter-stage 198 // inputs and outputs in one shader stage. 199 std::bitset<kMaxInterStageShaderVariables> usedInterStageVariables; 200 std::array<InterStageVariableInfo, kMaxInterStageShaderVariables> interStageVariables; 201 202 // The local workgroup size declared for a compute entry point (or 0s otehrwise). 203 Origin3D localWorkgroupSize; 204 205 // The shader stage for this binding. 206 SingleShaderStage stage; 207 208 struct OverridableConstant { 209 uint32_t id; 210 // Match tint::inspector::OverridableConstant::Type 211 // Bool is defined as a macro on linux X11 and cannot compile 212 enum class Type { Boolean, Float32, Uint32, Int32 } type; 213 214 // If the constant doesn't not have an initializer in the shader 215 // Then it is required for the pipeline stage to have a constant record to initialize a 216 // value 217 bool isInitialized; 218 219 // Store the default initialized value in shader 220 // This is used by metal backend as the function_constant does not have dafault values 221 // Initialized when isInitialized == true 222 OverridableConstantScalar defaultValue; 223 }; 224 225 using OverridableConstantsMap = std::unordered_map<std::string, OverridableConstant>; 226 227 // Map identifier to overridable constant 228 // Identifier is unique: either the variable name or the numeric ID if specified 229 OverridableConstantsMap overridableConstants; 230 231 // Overridable constants that are not initialized in shaders 232 // They need value initialization from pipeline stage or it is a validation error 233 std::unordered_set<std::string> uninitializedOverridableConstants; 234 235 // Store constants with shader initialized values as well 236 // This is used by metal backend to set values with default initializers that are not 237 // overridden 238 std::unordered_set<std::string> initializedOverridableConstants; 239 240 bool usesNumWorkgroups = false; 241 }; 242 243 class ShaderModuleBase : public ApiObjectBase, public CachedObject { 244 public: 245 ShaderModuleBase(DeviceBase* device, 246 const ShaderModuleDescriptor* descriptor, 247 ApiObjectBase::UntrackedByDeviceTag tag); 248 ShaderModuleBase(DeviceBase* device, const ShaderModuleDescriptor* descriptor); 249 ~ShaderModuleBase() override; 250 251 static Ref<ShaderModuleBase> MakeError(DeviceBase* device); 252 253 ObjectType GetType() const override; 254 255 // Return true iff the program has an entrypoint called `entryPoint`. 256 bool HasEntryPoint(const std::string& entryPoint) const; 257 258 // Return the metadata for the given `entryPoint`. HasEntryPoint with the same argument 259 // must be true. 260 const EntryPointMetadata& GetEntryPoint(const std::string& entryPoint) const; 261 262 // Functions necessary for the unordered_set<ShaderModuleBase*>-based cache. 263 size_t ComputeContentHash() override; 264 265 struct EqualityFunc { 266 bool operator()(const ShaderModuleBase* a, const ShaderModuleBase* b) const; 267 }; 268 269 const tint::Program* GetTintProgram() const; 270 271 void APIGetCompilationInfo(wgpu::CompilationInfoCallback callback, void* userdata); 272 273 void InjectCompilationMessages( 274 std::unique_ptr<OwnedCompilationMessages> compilationMessages); 275 276 OwnedCompilationMessages* GetCompilationMessages() const; 277 278 protected: 279 // Constructor used only for mocking and testing. 280 ShaderModuleBase(DeviceBase* device); 281 void DestroyImpl() override; 282 283 MaybeError InitializeBase(ShaderModuleParseResult* parseResult); 284 285 private: 286 ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag); 287 288 // The original data in the descriptor for caching. 289 enum class Type { Undefined, Spirv, Wgsl }; 290 Type mType; 291 std::vector<uint32_t> mOriginalSpirv; 292 std::string mWgsl; 293 294 EntryPointMetadataTable mEntryPoints; 295 std::unique_ptr<tint::Program> mTintProgram; 296 std::unique_ptr<TintSource> mTintSource; // Keep the tint::Source::File alive 297 298 std::unique_ptr<OwnedCompilationMessages> mCompilationMessages; 299 }; 300 301 } // namespace dawn_native 302 303 #endif // DAWNNATIVE_SHADERMODULE_H_ 304