• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2017 The Dawn Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "dawn_native/ShaderModule.h"
16 
17 #include "absl/strings/str_format.h"
18 #include "common/Constants.h"
19 #include "common/HashUtils.h"
20 #include "dawn_native/BindGroupLayout.h"
21 #include "dawn_native/ChainUtils_autogen.h"
22 #include "dawn_native/CompilationMessages.h"
23 #include "dawn_native/Device.h"
24 #include "dawn_native/ObjectContentHasher.h"
25 #include "dawn_native/Pipeline.h"
26 #include "dawn_native/PipelineLayout.h"
27 #include "dawn_native/RenderPipeline.h"
28 #include "dawn_native/TintUtils.h"
29 
30 #include <tint/tint.h>
31 
32 #include <sstream>
33 
34 namespace dawn_native {
35 
36     namespace {
37 
ToTintVertexFormat(wgpu::VertexFormat format)38         tint::transform::VertexFormat ToTintVertexFormat(wgpu::VertexFormat format) {
39             switch (format) {
40                 case wgpu::VertexFormat::Uint8x2:
41                     return tint::transform::VertexFormat::kUint8x2;
42                 case wgpu::VertexFormat::Uint8x4:
43                     return tint::transform::VertexFormat::kUint8x4;
44                 case wgpu::VertexFormat::Sint8x2:
45                     return tint::transform::VertexFormat::kSint8x2;
46                 case wgpu::VertexFormat::Sint8x4:
47                     return tint::transform::VertexFormat::kSint8x4;
48                 case wgpu::VertexFormat::Unorm8x2:
49                     return tint::transform::VertexFormat::kUnorm8x2;
50                 case wgpu::VertexFormat::Unorm8x4:
51                     return tint::transform::VertexFormat::kUnorm8x4;
52                 case wgpu::VertexFormat::Snorm8x2:
53                     return tint::transform::VertexFormat::kSnorm8x2;
54                 case wgpu::VertexFormat::Snorm8x4:
55                     return tint::transform::VertexFormat::kSnorm8x4;
56                 case wgpu::VertexFormat::Uint16x2:
57                     return tint::transform::VertexFormat::kUint16x2;
58                 case wgpu::VertexFormat::Uint16x4:
59                     return tint::transform::VertexFormat::kUint16x4;
60                 case wgpu::VertexFormat::Sint16x2:
61                     return tint::transform::VertexFormat::kSint16x2;
62                 case wgpu::VertexFormat::Sint16x4:
63                     return tint::transform::VertexFormat::kSint16x4;
64                 case wgpu::VertexFormat::Unorm16x2:
65                     return tint::transform::VertexFormat::kUnorm16x2;
66                 case wgpu::VertexFormat::Unorm16x4:
67                     return tint::transform::VertexFormat::kUnorm16x4;
68                 case wgpu::VertexFormat::Snorm16x2:
69                     return tint::transform::VertexFormat::kSnorm16x2;
70                 case wgpu::VertexFormat::Snorm16x4:
71                     return tint::transform::VertexFormat::kSnorm16x4;
72                 case wgpu::VertexFormat::Float16x2:
73                     return tint::transform::VertexFormat::kFloat16x2;
74                 case wgpu::VertexFormat::Float16x4:
75                     return tint::transform::VertexFormat::kFloat16x4;
76                 case wgpu::VertexFormat::Float32:
77                     return tint::transform::VertexFormat::kFloat32;
78                 case wgpu::VertexFormat::Float32x2:
79                     return tint::transform::VertexFormat::kFloat32x2;
80                 case wgpu::VertexFormat::Float32x3:
81                     return tint::transform::VertexFormat::kFloat32x3;
82                 case wgpu::VertexFormat::Float32x4:
83                     return tint::transform::VertexFormat::kFloat32x4;
84                 case wgpu::VertexFormat::Uint32:
85                     return tint::transform::VertexFormat::kUint32;
86                 case wgpu::VertexFormat::Uint32x2:
87                     return tint::transform::VertexFormat::kUint32x2;
88                 case wgpu::VertexFormat::Uint32x3:
89                     return tint::transform::VertexFormat::kUint32x3;
90                 case wgpu::VertexFormat::Uint32x4:
91                     return tint::transform::VertexFormat::kUint32x4;
92                 case wgpu::VertexFormat::Sint32:
93                     return tint::transform::VertexFormat::kSint32;
94                 case wgpu::VertexFormat::Sint32x2:
95                     return tint::transform::VertexFormat::kSint32x2;
96                 case wgpu::VertexFormat::Sint32x3:
97                     return tint::transform::VertexFormat::kSint32x3;
98                 case wgpu::VertexFormat::Sint32x4:
99                     return tint::transform::VertexFormat::kSint32x4;
100 
101                 case wgpu::VertexFormat::Undefined:
102                     break;
103             }
104             UNREACHABLE();
105         }
106 
ToTintVertexStepMode(wgpu::VertexStepMode mode)107         tint::transform::VertexStepMode ToTintVertexStepMode(wgpu::VertexStepMode mode) {
108             switch (mode) {
109                 case wgpu::VertexStepMode::Vertex:
110                     return tint::transform::VertexStepMode::kVertex;
111                 case wgpu::VertexStepMode::Instance:
112                     return tint::transform::VertexStepMode::kInstance;
113             }
114             UNREACHABLE();
115         }
116 
TintPipelineStageToShaderStage(tint::ast::PipelineStage stage)117         ResultOrError<SingleShaderStage> TintPipelineStageToShaderStage(
118             tint::ast::PipelineStage stage) {
119             switch (stage) {
120                 case tint::ast::PipelineStage::kVertex:
121                     return SingleShaderStage::Vertex;
122                 case tint::ast::PipelineStage::kFragment:
123                     return SingleShaderStage::Fragment;
124                 case tint::ast::PipelineStage::kCompute:
125                     return SingleShaderStage::Compute;
126                 case tint::ast::PipelineStage::kNone:
127                     break;
128             }
129             UNREACHABLE();
130         }
131 
TintResourceTypeToBindingInfoType(tint::inspector::ResourceBinding::ResourceType type)132         BindingInfoType TintResourceTypeToBindingInfoType(
133             tint::inspector::ResourceBinding::ResourceType type) {
134             switch (type) {
135                 case tint::inspector::ResourceBinding::ResourceType::kUniformBuffer:
136                 case tint::inspector::ResourceBinding::ResourceType::kStorageBuffer:
137                 case tint::inspector::ResourceBinding::ResourceType::kReadOnlyStorageBuffer:
138                     return BindingInfoType::Buffer;
139                 case tint::inspector::ResourceBinding::ResourceType::kSampler:
140                 case tint::inspector::ResourceBinding::ResourceType::kComparisonSampler:
141                     return BindingInfoType::Sampler;
142                 case tint::inspector::ResourceBinding::ResourceType::kSampledTexture:
143                 case tint::inspector::ResourceBinding::ResourceType::kMultisampledTexture:
144                 case tint::inspector::ResourceBinding::ResourceType::kDepthTexture:
145                 case tint::inspector::ResourceBinding::ResourceType::kDepthMultisampledTexture:
146                     return BindingInfoType::Texture;
147                 case tint::inspector::ResourceBinding::ResourceType::kWriteOnlyStorageTexture:
148                     return BindingInfoType::StorageTexture;
149                 case tint::inspector::ResourceBinding::ResourceType::kExternalTexture:
150                     return BindingInfoType::ExternalTexture;
151 
152                 default:
153                     UNREACHABLE();
154                     return BindingInfoType::Buffer;
155             }
156         }
157 
TintImageFormatToTextureFormat(tint::inspector::ResourceBinding::ImageFormat format)158         wgpu::TextureFormat TintImageFormatToTextureFormat(
159             tint::inspector::ResourceBinding::ImageFormat format) {
160             switch (format) {
161                 case tint::inspector::ResourceBinding::ImageFormat::kR8Unorm:
162                     return wgpu::TextureFormat::R8Unorm;
163                 case tint::inspector::ResourceBinding::ImageFormat::kR8Snorm:
164                     return wgpu::TextureFormat::R8Snorm;
165                 case tint::inspector::ResourceBinding::ImageFormat::kR8Uint:
166                     return wgpu::TextureFormat::R8Uint;
167                 case tint::inspector::ResourceBinding::ImageFormat::kR8Sint:
168                     return wgpu::TextureFormat::R8Sint;
169                 case tint::inspector::ResourceBinding::ImageFormat::kR16Uint:
170                     return wgpu::TextureFormat::R16Uint;
171                 case tint::inspector::ResourceBinding::ImageFormat::kR16Sint:
172                     return wgpu::TextureFormat::R16Sint;
173                 case tint::inspector::ResourceBinding::ImageFormat::kR16Float:
174                     return wgpu::TextureFormat::R16Float;
175                 case tint::inspector::ResourceBinding::ImageFormat::kRg8Unorm:
176                     return wgpu::TextureFormat::RG8Unorm;
177                 case tint::inspector::ResourceBinding::ImageFormat::kRg8Snorm:
178                     return wgpu::TextureFormat::RG8Snorm;
179                 case tint::inspector::ResourceBinding::ImageFormat::kRg8Uint:
180                     return wgpu::TextureFormat::RG8Uint;
181                 case tint::inspector::ResourceBinding::ImageFormat::kRg8Sint:
182                     return wgpu::TextureFormat::RG8Sint;
183                 case tint::inspector::ResourceBinding::ImageFormat::kR32Uint:
184                     return wgpu::TextureFormat::R32Uint;
185                 case tint::inspector::ResourceBinding::ImageFormat::kR32Sint:
186                     return wgpu::TextureFormat::R32Sint;
187                 case tint::inspector::ResourceBinding::ImageFormat::kR32Float:
188                     return wgpu::TextureFormat::R32Float;
189                 case tint::inspector::ResourceBinding::ImageFormat::kRg16Uint:
190                     return wgpu::TextureFormat::RG16Uint;
191                 case tint::inspector::ResourceBinding::ImageFormat::kRg16Sint:
192                     return wgpu::TextureFormat::RG16Sint;
193                 case tint::inspector::ResourceBinding::ImageFormat::kRg16Float:
194                     return wgpu::TextureFormat::RG16Float;
195                 case tint::inspector::ResourceBinding::ImageFormat::kRgba8Unorm:
196                     return wgpu::TextureFormat::RGBA8Unorm;
197                 case tint::inspector::ResourceBinding::ImageFormat::kRgba8UnormSrgb:
198                     return wgpu::TextureFormat::RGBA8UnormSrgb;
199                 case tint::inspector::ResourceBinding::ImageFormat::kRgba8Snorm:
200                     return wgpu::TextureFormat::RGBA8Snorm;
201                 case tint::inspector::ResourceBinding::ImageFormat::kRgba8Uint:
202                     return wgpu::TextureFormat::RGBA8Uint;
203                 case tint::inspector::ResourceBinding::ImageFormat::kRgba8Sint:
204                     return wgpu::TextureFormat::RGBA8Sint;
205                 case tint::inspector::ResourceBinding::ImageFormat::kBgra8Unorm:
206                     return wgpu::TextureFormat::BGRA8Unorm;
207                 case tint::inspector::ResourceBinding::ImageFormat::kBgra8UnormSrgb:
208                     return wgpu::TextureFormat::BGRA8UnormSrgb;
209                 case tint::inspector::ResourceBinding::ImageFormat::kRgb10A2Unorm:
210                     return wgpu::TextureFormat::RGB10A2Unorm;
211                 case tint::inspector::ResourceBinding::ImageFormat::kRg11B10Float:
212                     return wgpu::TextureFormat::RG11B10Ufloat;
213                 case tint::inspector::ResourceBinding::ImageFormat::kRg32Uint:
214                     return wgpu::TextureFormat::RG32Uint;
215                 case tint::inspector::ResourceBinding::ImageFormat::kRg32Sint:
216                     return wgpu::TextureFormat::RG32Sint;
217                 case tint::inspector::ResourceBinding::ImageFormat::kRg32Float:
218                     return wgpu::TextureFormat::RG32Float;
219                 case tint::inspector::ResourceBinding::ImageFormat::kRgba16Uint:
220                     return wgpu::TextureFormat::RGBA16Uint;
221                 case tint::inspector::ResourceBinding::ImageFormat::kRgba16Sint:
222                     return wgpu::TextureFormat::RGBA16Sint;
223                 case tint::inspector::ResourceBinding::ImageFormat::kRgba16Float:
224                     return wgpu::TextureFormat::RGBA16Float;
225                 case tint::inspector::ResourceBinding::ImageFormat::kRgba32Uint:
226                     return wgpu::TextureFormat::RGBA32Uint;
227                 case tint::inspector::ResourceBinding::ImageFormat::kRgba32Sint:
228                     return wgpu::TextureFormat::RGBA32Sint;
229                 case tint::inspector::ResourceBinding::ImageFormat::kRgba32Float:
230                     return wgpu::TextureFormat::RGBA32Float;
231                 case tint::inspector::ResourceBinding::ImageFormat::kNone:
232                     return wgpu::TextureFormat::Undefined;
233             }
234             UNREACHABLE();
235         }
236 
TintTextureDimensionToTextureViewDimension(tint::inspector::ResourceBinding::TextureDimension dim)237         wgpu::TextureViewDimension TintTextureDimensionToTextureViewDimension(
238             tint::inspector::ResourceBinding::TextureDimension dim) {
239             switch (dim) {
240                 case tint::inspector::ResourceBinding::TextureDimension::k1d:
241                     return wgpu::TextureViewDimension::e1D;
242                 case tint::inspector::ResourceBinding::TextureDimension::k2d:
243                     return wgpu::TextureViewDimension::e2D;
244                 case tint::inspector::ResourceBinding::TextureDimension::k2dArray:
245                     return wgpu::TextureViewDimension::e2DArray;
246                 case tint::inspector::ResourceBinding::TextureDimension::k3d:
247                     return wgpu::TextureViewDimension::e3D;
248                 case tint::inspector::ResourceBinding::TextureDimension::kCube:
249                     return wgpu::TextureViewDimension::Cube;
250                 case tint::inspector::ResourceBinding::TextureDimension::kCubeArray:
251                     return wgpu::TextureViewDimension::CubeArray;
252                 case tint::inspector::ResourceBinding::TextureDimension::kNone:
253                     return wgpu::TextureViewDimension::Undefined;
254             }
255             UNREACHABLE();
256         }
257 
TintSampledKindToSampleTypeBit(tint::inspector::ResourceBinding::SampledKind s)258         SampleTypeBit TintSampledKindToSampleTypeBit(
259             tint::inspector::ResourceBinding::SampledKind s) {
260             switch (s) {
261                 case tint::inspector::ResourceBinding::SampledKind::kSInt:
262                     return SampleTypeBit::Sint;
263                 case tint::inspector::ResourceBinding::SampledKind::kUInt:
264                     return SampleTypeBit::Uint;
265                 case tint::inspector::ResourceBinding::SampledKind::kFloat:
266                     return SampleTypeBit::Float | SampleTypeBit::UnfilterableFloat;
267                 case tint::inspector::ResourceBinding::SampledKind::kUnknown:
268                     return SampleTypeBit::None;
269             }
270             UNREACHABLE();
271         }
272 
TintComponentTypeToTextureComponentType(tint::inspector::ComponentType type)273         ResultOrError<wgpu::TextureComponentType> TintComponentTypeToTextureComponentType(
274             tint::inspector::ComponentType type) {
275             switch (type) {
276                 case tint::inspector::ComponentType::kFloat:
277                     return wgpu::TextureComponentType::Float;
278                 case tint::inspector::ComponentType::kSInt:
279                     return wgpu::TextureComponentType::Sint;
280                 case tint::inspector::ComponentType::kUInt:
281                     return wgpu::TextureComponentType::Uint;
282                 case tint::inspector::ComponentType::kUnknown:
283                     return DAWN_VALIDATION_ERROR(
284                         "Attempted to convert 'Unknown' component type from Tint");
285             }
286             UNREACHABLE();
287         }
288 
TintComponentTypeToVertexFormatBaseType(tint::inspector::ComponentType type)289         ResultOrError<VertexFormatBaseType> TintComponentTypeToVertexFormatBaseType(
290             tint::inspector::ComponentType type) {
291             switch (type) {
292                 case tint::inspector::ComponentType::kFloat:
293                     return VertexFormatBaseType::Float;
294                 case tint::inspector::ComponentType::kSInt:
295                     return VertexFormatBaseType::Sint;
296                 case tint::inspector::ComponentType::kUInt:
297                     return VertexFormatBaseType::Uint;
298                 case tint::inspector::ComponentType::kUnknown:
299                     return DAWN_VALIDATION_ERROR(
300                         "Attempted to convert 'Unknown' component type from Tint");
301             }
302             UNREACHABLE();
303         }
304 
TintResourceTypeToBufferBindingType(tint::inspector::ResourceBinding::ResourceType resource_type)305         ResultOrError<wgpu::BufferBindingType> TintResourceTypeToBufferBindingType(
306             tint::inspector::ResourceBinding::ResourceType resource_type) {
307             switch (resource_type) {
308                 case tint::inspector::ResourceBinding::ResourceType::kUniformBuffer:
309                     return wgpu::BufferBindingType::Uniform;
310                 case tint::inspector::ResourceBinding::ResourceType::kStorageBuffer:
311                     return wgpu::BufferBindingType::Storage;
312                 case tint::inspector::ResourceBinding::ResourceType::kReadOnlyStorageBuffer:
313                     return wgpu::BufferBindingType::ReadOnlyStorage;
314                 default:
315                     return DAWN_VALIDATION_ERROR("Attempted to convert non-buffer resource type");
316             }
317             UNREACHABLE();
318         }
319 
TintResourceTypeToStorageTextureAccess(tint::inspector::ResourceBinding::ResourceType resource_type)320         ResultOrError<wgpu::StorageTextureAccess> TintResourceTypeToStorageTextureAccess(
321             tint::inspector::ResourceBinding::ResourceType resource_type) {
322             switch (resource_type) {
323                 case tint::inspector::ResourceBinding::ResourceType::kWriteOnlyStorageTexture:
324                     return wgpu::StorageTextureAccess::WriteOnly;
325                 default:
326                     return DAWN_VALIDATION_ERROR(
327                         "Attempted to convert non-storage texture resource type");
328             }
329             UNREACHABLE();
330         }
331 
TintComponentTypeToInterStageComponentType(tint::inspector::ComponentType type)332         ResultOrError<InterStageComponentType> TintComponentTypeToInterStageComponentType(
333             tint::inspector::ComponentType type) {
334             switch (type) {
335                 case tint::inspector::ComponentType::kFloat:
336                     return InterStageComponentType::Float;
337                 case tint::inspector::ComponentType::kSInt:
338                     return InterStageComponentType::Sint;
339                 case tint::inspector::ComponentType::kUInt:
340                     return InterStageComponentType::Uint;
341                 case tint::inspector::ComponentType::kUnknown:
342                     return DAWN_VALIDATION_ERROR(
343                         "Attempted to convert 'Unknown' component type from Tint");
344             }
345             UNREACHABLE();
346         }
347 
TintCompositionTypeToInterStageComponentCount(tint::inspector::CompositionType type)348         ResultOrError<uint32_t> TintCompositionTypeToInterStageComponentCount(
349             tint::inspector::CompositionType type) {
350             switch (type) {
351                 case tint::inspector::CompositionType::kScalar:
352                     return 1u;
353                 case tint::inspector::CompositionType::kVec2:
354                     return 2u;
355                 case tint::inspector::CompositionType::kVec3:
356                     return 3u;
357                 case tint::inspector::CompositionType::kVec4:
358                     return 4u;
359                 case tint::inspector::CompositionType::kUnknown:
360                     return DAWN_VALIDATION_ERROR(
361                         "Attempt to convert 'Unknown' composition type from Tint");
362             }
363             UNREACHABLE();
364         }
365 
TintInterpolationTypeToInterpolationType(tint::inspector::InterpolationType type)366         ResultOrError<InterpolationType> TintInterpolationTypeToInterpolationType(
367             tint::inspector::InterpolationType type) {
368             switch (type) {
369                 case tint::inspector::InterpolationType::kPerspective:
370                     return InterpolationType::Perspective;
371                 case tint::inspector::InterpolationType::kLinear:
372                     return InterpolationType::Linear;
373                 case tint::inspector::InterpolationType::kFlat:
374                     return InterpolationType::Flat;
375                 case tint::inspector::InterpolationType::kUnknown:
376                     return DAWN_VALIDATION_ERROR(
377                         "Attempted to convert 'Unknown' interpolation type from Tint");
378             }
379             UNREACHABLE();
380         }
381 
TintInterpolationSamplingToInterpolationSamplingType(tint::inspector::InterpolationSampling type)382         ResultOrError<InterpolationSampling> TintInterpolationSamplingToInterpolationSamplingType(
383             tint::inspector::InterpolationSampling type) {
384             switch (type) {
385                 case tint::inspector::InterpolationSampling::kNone:
386                     return InterpolationSampling::None;
387                 case tint::inspector::InterpolationSampling::kCenter:
388                     return InterpolationSampling::Center;
389                 case tint::inspector::InterpolationSampling::kCentroid:
390                     return InterpolationSampling::Centroid;
391                 case tint::inspector::InterpolationSampling::kSample:
392                     return InterpolationSampling::Sample;
393                 case tint::inspector::InterpolationSampling::kUnknown:
394                     return DAWN_VALIDATION_ERROR(
395                         "Attempted to convert 'Unknown' interpolation sampling type from Tint");
396             }
397             UNREACHABLE();
398         }
399 
FromTintOverridableConstantType(tint::inspector::OverridableConstant::Type type)400         EntryPointMetadata::OverridableConstant::Type FromTintOverridableConstantType(
401             tint::inspector::OverridableConstant::Type type) {
402             switch (type) {
403                 case tint::inspector::OverridableConstant::Type::kBool:
404                     return EntryPointMetadata::OverridableConstant::Type::Boolean;
405                 case tint::inspector::OverridableConstant::Type::kFloat32:
406                     return EntryPointMetadata::OverridableConstant::Type::Float32;
407                 case tint::inspector::OverridableConstant::Type::kInt32:
408                     return EntryPointMetadata::OverridableConstant::Type::Int32;
409                 case tint::inspector::OverridableConstant::Type::kUint32:
410                     return EntryPointMetadata::OverridableConstant::Type::Uint32;
411                 default:
412                     UNREACHABLE();
413             }
414         }
415 
ParseWGSL(const tint::Source::File * file,OwnedCompilationMessages * outMessages)416         ResultOrError<tint::Program> ParseWGSL(const tint::Source::File* file,
417                                                OwnedCompilationMessages* outMessages) {
418             tint::Program program = tint::reader::wgsl::Parse(file);
419             if (outMessages != nullptr) {
420                 outMessages->AddMessages(program.Diagnostics());
421             }
422             if (!program.IsValid()) {
423                 return DAWN_FORMAT_VALIDATION_ERROR(
424                     "Tint WGSL reader failure:\nParser: %s\nShader:\n%s\n",
425                     program.Diagnostics().str(), file->content.data);
426             }
427 
428             return std::move(program);
429         }
430 
ParseSPIRV(const std::vector<uint32_t> & spirv,OwnedCompilationMessages * outMessages)431         ResultOrError<tint::Program> ParseSPIRV(const std::vector<uint32_t>& spirv,
432                                                 OwnedCompilationMessages* outMessages) {
433             tint::Program program = tint::reader::spirv::Parse(spirv);
434             if (outMessages != nullptr) {
435                 outMessages->AddMessages(program.Diagnostics());
436             }
437             if (!program.IsValid()) {
438                 return DAWN_FORMAT_VALIDATION_ERROR("Tint SPIR-V reader failure:\nParser: %s\n",
439                                                     program.Diagnostics().str());
440             }
441 
442             return std::move(program);
443         }
444 
GetBindGroupMinBufferSizes(const BindingGroupInfoMap & shaderBindings,const BindGroupLayoutBase * layout)445         std::vector<uint64_t> GetBindGroupMinBufferSizes(const BindingGroupInfoMap& shaderBindings,
446                                                          const BindGroupLayoutBase* layout) {
447             std::vector<uint64_t> requiredBufferSizes(layout->GetUnverifiedBufferCount());
448             uint32_t packedIdx = 0;
449 
450             for (BindingIndex bindingIndex{0}; bindingIndex < layout->GetBufferCount();
451                  ++bindingIndex) {
452                 const BindingInfo& bindingInfo = layout->GetBindingInfo(bindingIndex);
453                 if (bindingInfo.buffer.minBindingSize != 0) {
454                     // Skip bindings that have minimum buffer size set in the layout
455                     continue;
456                 }
457 
458                 ASSERT(packedIdx < requiredBufferSizes.size());
459                 const auto& shaderInfo = shaderBindings.find(bindingInfo.binding);
460                 if (shaderInfo != shaderBindings.end()) {
461                     requiredBufferSizes[packedIdx] = shaderInfo->second.buffer.minBindingSize;
462                 } else {
463                     // We have to include buffers if they are included in the bind group's
464                     // packed vector. We don't actually need to check these at draw time, so
465                     // if this is a problem in the future we can optimize it further.
466                     requiredBufferSizes[packedIdx] = 0;
467                 }
468                 ++packedIdx;
469             }
470 
471             return requiredBufferSizes;
472         }
473 
ValidateCompatibilityOfSingleBindingWithLayout(const DeviceBase * device,const BindGroupLayoutBase * layout,SingleShaderStage entryPointStage,BindingNumber bindingNumber,const ShaderBindingInfo & shaderInfo)474         MaybeError ValidateCompatibilityOfSingleBindingWithLayout(
475             const DeviceBase* device,
476             const BindGroupLayoutBase* layout,
477             SingleShaderStage entryPointStage,
478             BindingNumber bindingNumber,
479             const ShaderBindingInfo& shaderInfo) {
480             const BindGroupLayoutBase::BindingMap& layoutBindings = layout->GetBindingMap();
481 
482             const auto& bindingIt = layoutBindings.find(bindingNumber);
483             DAWN_INVALID_IF(bindingIt == layoutBindings.end(), "Binding doesn't exist in %s.",
484                             layout);
485 
486             BindingIndex bindingIndex(bindingIt->second);
487             const BindingInfo& layoutInfo = layout->GetBindingInfo(bindingIndex);
488 
489             // TODO(dawn:563): Provide info about the binding types.
490             DAWN_INVALID_IF(layoutInfo.bindingType != shaderInfo.bindingType,
491                             "Binding type (buffer vs. texture vs. sampler) doesn't match the type "
492                             "in the layout.");
493 
494             // TODO(dawn:563): Provide info about the visibility.
495             DAWN_INVALID_IF(
496                 (layoutInfo.visibility & StageBit(entryPointStage)) == 0,
497                 "Entry point's stage is not in the binding visibility in the layout (%s)",
498                 layoutInfo.visibility);
499 
500             switch (layoutInfo.bindingType) {
501                 case BindingInfoType::Texture: {
502                     DAWN_INVALID_IF(
503                         layoutInfo.texture.multisampled != shaderInfo.texture.multisampled,
504                         "Binding multisampled flag (%u) doesn't match the layout's multisampled "
505                         "flag (%u)",
506                         layoutInfo.texture.multisampled, shaderInfo.texture.multisampled);
507 
508                     // TODO(dawn:563): Provide info about the sample types.
509                     DAWN_INVALID_IF((SampleTypeToSampleTypeBit(layoutInfo.texture.sampleType) &
510                                      shaderInfo.texture.compatibleSampleTypes) == 0,
511                                     "The sample type in the shader is not compatible with the "
512                                     "sample type of the layout.");
513 
514                     DAWN_INVALID_IF(
515                         layoutInfo.texture.viewDimension != shaderInfo.texture.viewDimension,
516                         "The shader's binding dimension (%s) doesn't match the shader's binding "
517                         "dimension (%s).",
518                         layoutInfo.texture.viewDimension, shaderInfo.texture.viewDimension);
519                     break;
520                 }
521 
522                 case BindingInfoType::StorageTexture: {
523                     ASSERT(layoutInfo.storageTexture.format != wgpu::TextureFormat::Undefined);
524                     ASSERT(shaderInfo.storageTexture.format != wgpu::TextureFormat::Undefined);
525 
526                     DAWN_INVALID_IF(
527                         layoutInfo.storageTexture.access != shaderInfo.storageTexture.access,
528                         "The layout's binding access (%s) isn't compatible with the shader's "
529                         "binding access (%s).",
530                         layoutInfo.storageTexture.access, shaderInfo.storageTexture.access);
531 
532                     DAWN_INVALID_IF(
533                         layoutInfo.storageTexture.format != shaderInfo.storageTexture.format,
534                         "The layout's binding format (%s) doesn't match the shader's binding "
535                         "format (%s).",
536                         layoutInfo.storageTexture.format, shaderInfo.storageTexture.format);
537 
538                     DAWN_INVALID_IF(layoutInfo.storageTexture.viewDimension !=
539                                         shaderInfo.storageTexture.viewDimension,
540                                     "The layout's binding dimension (%s) doesn't match the "
541                                     "shader's binding dimension (%s).",
542                                     layoutInfo.storageTexture.viewDimension,
543                                     shaderInfo.storageTexture.viewDimension);
544                     break;
545                 }
546 
547                 case BindingInfoType::ExternalTexture: {
548                     // Nothing to validate! (yet?)
549                     break;
550                 }
551 
552                 case BindingInfoType::Buffer: {
553                     // Binding mismatch between shader and bind group is invalid. For example, a
554                     // writable binding in the shader with a readonly storage buffer in the bind
555                     // group layout is invalid. However, a readonly binding in the shader with a
556                     // writable storage buffer in the bind group layout is valid, a storage
557                     // binding in the shader with an internal storage buffer in the bind group
558                     // layout is also valid.
559                     bool validBindingConversion =
560                         (layoutInfo.buffer.type == wgpu::BufferBindingType::Storage &&
561                          shaderInfo.buffer.type == wgpu::BufferBindingType::ReadOnlyStorage) ||
562                         (layoutInfo.buffer.type == kInternalStorageBufferBinding &&
563                          shaderInfo.buffer.type == wgpu::BufferBindingType::Storage);
564 
565                     DAWN_INVALID_IF(
566                         layoutInfo.buffer.type != shaderInfo.buffer.type && !validBindingConversion,
567                         "The buffer type in the shader (%s) is not compatible with the type in the "
568                         "layout (%s).",
569                         shaderInfo.buffer.type, layoutInfo.buffer.type);
570 
571                     DAWN_INVALID_IF(
572                         layoutInfo.buffer.minBindingSize != 0 &&
573                             shaderInfo.buffer.minBindingSize > layoutInfo.buffer.minBindingSize,
574                         "The shader uses more bytes of the buffer (%u) than the layout's "
575                         "minBindingSize (%u).",
576                         shaderInfo.buffer.minBindingSize, layoutInfo.buffer.minBindingSize);
577                     break;
578                 }
579 
580                 case BindingInfoType::Sampler:
581                     DAWN_INVALID_IF(
582                         (layoutInfo.sampler.type == wgpu::SamplerBindingType::Comparison) !=
583                             shaderInfo.sampler.isComparison,
584                         "The sampler type in the shader (comparison: %u) doesn't match the type in "
585                         "the layout (comparison: %u).",
586                         shaderInfo.sampler.isComparison,
587                         layoutInfo.sampler.type == wgpu::SamplerBindingType::Comparison);
588                     break;
589             }
590 
591             return {};
592         }
ValidateCompatibilityWithBindGroupLayout(DeviceBase * device,BindGroupIndex group,const EntryPointMetadata & entryPoint,const BindGroupLayoutBase * layout)593         MaybeError ValidateCompatibilityWithBindGroupLayout(DeviceBase* device,
594                                                             BindGroupIndex group,
595                                                             const EntryPointMetadata& entryPoint,
596                                                             const BindGroupLayoutBase* layout) {
597             // Iterate over all bindings used by this group in the shader, and find the
598             // corresponding binding in the BindGroupLayout, if it exists.
599             for (const auto& it : entryPoint.bindings[group]) {
600                 DAWN_TRY_CONTEXT(ValidateCompatibilityOfSingleBindingWithLayout(
601                                      device, layout, entryPoint.stage, it.first, it.second),
602                                  "validating that the entry-point's declaration for [[group(%u), "
603                                  "binding(%u)]] matches %s",
604                                  static_cast<uint32_t>(group), static_cast<uint32_t>(it.first),
605                                  layout);
606             }
607 
608             return {};
609         }
610 
ReflectShaderUsingTint(const DeviceBase * device,const tint::Program * program)611         ResultOrError<EntryPointMetadataTable> ReflectShaderUsingTint(
612             const DeviceBase* device,
613             const tint::Program* program) {
614             ASSERT(program->IsValid());
615 
616             const CombinedLimits& limits = device->GetLimits();
617 
618             EntryPointMetadataTable result;
619 
620             tint::inspector::Inspector inspector(program);
621             auto entryPoints = inspector.GetEntryPoints();
622             DAWN_INVALID_IF(inspector.has_error(), "Tint Reflection failure: Inspector: %s\n",
623                             inspector.error());
624 
625             // TODO(dawn:563): use DAWN_TRY_CONTEXT to output the name of the entry point we're
626             // reflecting.
627             constexpr uint32_t kMaxInterStageShaderLocation = kMaxInterStageShaderVariables - 1;
628             for (auto& entryPoint : entryPoints) {
629                 ASSERT(result.count(entryPoint.name) == 0);
630 
631                 auto metadata = std::make_unique<EntryPointMetadata>();
632 
633                 if (!entryPoint.overridable_constants.empty()) {
634                     DAWN_INVALID_IF(device->IsToggleEnabled(Toggle::DisallowUnsafeAPIs),
635                                     "Pipeline overridable constants are disallowed because they "
636                                     "are partially implemented.");
637 
638                     const auto& name2Id = inspector.GetConstantNameToIdMap();
639                     const auto& id2Scalar = inspector.GetConstantIDs();
640 
641                     for (auto& c : entryPoint.overridable_constants) {
642                         uint32_t id = name2Id.at(c.name);
643                         OverridableConstantScalar defaultValue;
644                         if (c.is_initialized) {
645                             // if it is initialized, the scalar must exist
646                             const auto& scalar = id2Scalar.at(id);
647                             if (scalar.IsBool()) {
648                                 defaultValue.b = scalar.AsBool();
649                             } else if (scalar.IsU32()) {
650                                 defaultValue.u32 = scalar.AsU32();
651                             } else if (scalar.IsI32()) {
652                                 defaultValue.i32 = scalar.AsI32();
653                             } else if (scalar.IsFloat()) {
654                                 defaultValue.f32 = scalar.AsFloat();
655                             } else {
656                                 UNREACHABLE();
657                             }
658                         }
659                         EntryPointMetadata::OverridableConstant constant = {
660                             id, FromTintOverridableConstantType(c.type), c.is_initialized,
661                             defaultValue};
662 
663                         std::string identifier =
664                             c.is_numeric_id_specified ? std::to_string(constant.id) : c.name;
665                         metadata->overridableConstants[identifier] = constant;
666 
667                         if (!c.is_initialized) {
668                             auto it = metadata->uninitializedOverridableConstants.emplace(
669                                 std::move(identifier));
670                             // The insertion should have taken place
671                             ASSERT(it.second);
672                         } else {
673                             auto it = metadata->initializedOverridableConstants.emplace(
674                                 std::move(identifier));
675                             // The insertion should have taken place
676                             ASSERT(it.second);
677                         }
678                     }
679                 }
680 
681                 DAWN_TRY_ASSIGN(metadata->stage, TintPipelineStageToShaderStage(entryPoint.stage));
682 
683                 if (metadata->stage == SingleShaderStage::Compute) {
684                     DAWN_INVALID_IF(
685                         entryPoint.workgroup_size_x > limits.v1.maxComputeWorkgroupSizeX ||
686                             entryPoint.workgroup_size_y > limits.v1.maxComputeWorkgroupSizeY ||
687                             entryPoint.workgroup_size_z > limits.v1.maxComputeWorkgroupSizeZ,
688                         "Entry-point uses workgroup_size(%u, %u, %u) that exceeds the "
689                         "maximum allowed (%u, %u, %u).",
690                         entryPoint.workgroup_size_x, entryPoint.workgroup_size_y,
691                         entryPoint.workgroup_size_z, limits.v1.maxComputeWorkgroupSizeX,
692                         limits.v1.maxComputeWorkgroupSizeY, limits.v1.maxComputeWorkgroupSizeZ);
693 
694                     // Dimensions have already been validated against their individual limits above.
695                     // Cast to uint64_t to avoid overflow in this multiplication.
696                     uint64_t numInvocations = static_cast<uint64_t>(entryPoint.workgroup_size_x) *
697                                               entryPoint.workgroup_size_y *
698                                               entryPoint.workgroup_size_z;
699                     DAWN_INVALID_IF(numInvocations > limits.v1.maxComputeInvocationsPerWorkgroup,
700                                     "The total number of workgroup invocations (%u) exceeds the "
701                                     "maximum allowed (%u).",
702                                     numInvocations, limits.v1.maxComputeInvocationsPerWorkgroup);
703 
704                     const size_t workgroupStorageSize =
705                         inspector.GetWorkgroupStorageSize(entryPoint.name);
706                     DAWN_INVALID_IF(workgroupStorageSize > limits.v1.maxComputeWorkgroupStorageSize,
707                                     "The total use of workgroup storage (%u bytes) is larger than "
708                                     "the maximum allowed (%u bytes).",
709                                     workgroupStorageSize, limits.v1.maxComputeWorkgroupStorageSize);
710 
711                     metadata->localWorkgroupSize.x = entryPoint.workgroup_size_x;
712                     metadata->localWorkgroupSize.y = entryPoint.workgroup_size_y;
713                     metadata->localWorkgroupSize.z = entryPoint.workgroup_size_z;
714 
715                     metadata->usesNumWorkgroups = entryPoint.num_workgroups_used;
716                 }
717 
718                 if (metadata->stage == SingleShaderStage::Vertex) {
719                     for (const auto& inputVar : entryPoint.input_variables) {
720                         DAWN_INVALID_IF(
721                             !inputVar.has_location_decoration,
722                             "Vertex input variable \"%s\" doesn't have a location decoration.",
723                             inputVar.name);
724 
725                         uint32_t unsanitizedLocation = inputVar.location_decoration;
726                         DAWN_INVALID_IF(unsanitizedLocation >= kMaxVertexAttributes,
727                                         "Vertex input variable \"%s\" has a location (%u) that "
728                                         "exceeds the maximum (%u)",
729                                         inputVar.name, unsanitizedLocation, kMaxVertexAttributes);
730                         VertexAttributeLocation location(static_cast<uint8_t>(unsanitizedLocation));
731 
732                         DAWN_TRY_ASSIGN(
733                             metadata->vertexInputBaseTypes[location],
734                             TintComponentTypeToVertexFormatBaseType(inputVar.component_type));
735                         metadata->usedVertexInputs.set(location);
736                     }
737 
738                     // [[position]] must be declared in a vertex shader but is not exposed as an
739                     // output variable by Tint so we directly add its components to the total.
740                     uint32_t totalInterStageShaderComponents = 4;
741                     for (const auto& outputVar : entryPoint.output_variables) {
742                         DAWN_INVALID_IF(
743                             !outputVar.has_location_decoration,
744                             "Vertex ouput variable \"%s\" doesn't have a location decoration.",
745                             outputVar.name);
746 
747                         uint32_t location = outputVar.location_decoration;
748                         DAWN_INVALID_IF(location > kMaxInterStageShaderLocation,
749                                         "Vertex output variable \"%s\" has a location (%u) that "
750                                         "exceeds the maximum (%u).",
751                                         outputVar.name, location, kMaxInterStageShaderLocation);
752 
753                         metadata->usedInterStageVariables.set(location);
754                         DAWN_TRY_ASSIGN(
755                             metadata->interStageVariables[location].baseType,
756                             TintComponentTypeToInterStageComponentType(outputVar.component_type));
757                         DAWN_TRY_ASSIGN(metadata->interStageVariables[location].componentCount,
758                                         TintCompositionTypeToInterStageComponentCount(
759                                             outputVar.composition_type));
760                         DAWN_TRY_ASSIGN(
761                             metadata->interStageVariables[location].interpolationType,
762                             TintInterpolationTypeToInterpolationType(outputVar.interpolation_type));
763                         DAWN_TRY_ASSIGN(
764                             metadata->interStageVariables[location].interpolationSampling,
765                             TintInterpolationSamplingToInterpolationSamplingType(
766                                 outputVar.interpolation_sampling));
767 
768                         totalInterStageShaderComponents +=
769                             metadata->interStageVariables[location].componentCount;
770                     }
771 
772                     DAWN_INVALID_IF(
773                         totalInterStageShaderComponents > kMaxInterStageShaderComponents,
774                         "Total vertex output components count (%u) exceeds the maximum (%u).",
775                         totalInterStageShaderComponents, kMaxInterStageShaderComponents);
776                 }
777 
778                 if (metadata->stage == SingleShaderStage::Fragment) {
779                     uint32_t totalInterStageShaderComponents = 0;
780                     for (const auto& inputVar : entryPoint.input_variables) {
781                         DAWN_INVALID_IF(
782                             !inputVar.has_location_decoration,
783                             "Fragment input variable \"%s\" doesn't have a location decoration.",
784                             inputVar.name);
785 
786                         uint32_t location = inputVar.location_decoration;
787                         DAWN_INVALID_IF(location > kMaxInterStageShaderLocation,
788                                         "Fragment input variable \"%s\" has a location (%u) that "
789                                         "exceeds the maximum (%u).",
790                                         inputVar.name, location, kMaxInterStageShaderLocation);
791 
792                         metadata->usedInterStageVariables.set(location);
793                         DAWN_TRY_ASSIGN(
794                             metadata->interStageVariables[location].baseType,
795                             TintComponentTypeToInterStageComponentType(inputVar.component_type));
796                         DAWN_TRY_ASSIGN(metadata->interStageVariables[location].componentCount,
797                                         TintCompositionTypeToInterStageComponentCount(
798                                             inputVar.composition_type));
799                         DAWN_TRY_ASSIGN(
800                             metadata->interStageVariables[location].interpolationType,
801                             TintInterpolationTypeToInterpolationType(inputVar.interpolation_type));
802                         DAWN_TRY_ASSIGN(
803                             metadata->interStageVariables[location].interpolationSampling,
804                             TintInterpolationSamplingToInterpolationSamplingType(
805                                 inputVar.interpolation_sampling));
806 
807                         totalInterStageShaderComponents +=
808                             metadata->interStageVariables[location].componentCount;
809                     }
810 
811                     if (entryPoint.front_facing_used) {
812                         totalInterStageShaderComponents += 1;
813                     }
814                     if (entryPoint.input_sample_mask_used) {
815                         totalInterStageShaderComponents += 1;
816                     }
817                     if (entryPoint.sample_index_used) {
818                         totalInterStageShaderComponents += 1;
819                     }
820                     if (entryPoint.input_position_used) {
821                         totalInterStageShaderComponents += 4;
822                     }
823 
824                     DAWN_INVALID_IF(
825                         totalInterStageShaderComponents > kMaxInterStageShaderComponents,
826                         "Total fragment input components count (%u) exceeds the maximum (%u).",
827                         totalInterStageShaderComponents, kMaxInterStageShaderComponents);
828 
829                     for (const auto& outputVar : entryPoint.output_variables) {
830                         DAWN_INVALID_IF(
831                             !outputVar.has_location_decoration,
832                             "Fragment input variable \"%s\" doesn't have a location decoration.",
833                             outputVar.name);
834 
835                         uint32_t unsanitizedAttachment = outputVar.location_decoration;
836                         DAWN_INVALID_IF(unsanitizedAttachment >= kMaxColorAttachments,
837                                         "Fragment output variable \"%s\" has a location (%u) that "
838                                         "exceeds the maximum (%u).",
839                                         outputVar.name, unsanitizedAttachment,
840                                         kMaxColorAttachments);
841                         ColorAttachmentIndex attachment(
842                             static_cast<uint8_t>(unsanitizedAttachment));
843 
844                         DAWN_TRY_ASSIGN(
845                             metadata->fragmentOutputVariables[attachment].baseType,
846                             TintComponentTypeToTextureComponentType(outputVar.component_type));
847                         uint32_t componentCount;
848                         DAWN_TRY_ASSIGN(componentCount,
849                                         TintCompositionTypeToInterStageComponentCount(
850                                             outputVar.composition_type));
851                         // componentCount should be no larger than 4u
852                         ASSERT(componentCount <= 4u);
853                         metadata->fragmentOutputVariables[attachment].componentCount =
854                             componentCount;
855                         metadata->fragmentOutputsWritten.set(attachment);
856                     }
857                 }
858 
859                 for (const tint::inspector::ResourceBinding& resource :
860                      inspector.GetResourceBindings(entryPoint.name)) {
861                     DAWN_INVALID_IF(resource.bind_group >= kMaxBindGroups,
862                                     "The entry-point uses a binding with a group decoration (%u) "
863                                     "that exceeds the maximum (%u).",
864                                     resource.bind_group, kMaxBindGroups);
865 
866                     BindingNumber bindingNumber(resource.binding);
867                     BindGroupIndex bindGroupIndex(resource.bind_group);
868 
869                     const auto& it = metadata->bindings[bindGroupIndex].emplace(
870                         bindingNumber, ShaderBindingInfo{});
871                     DAWN_INVALID_IF(
872                         !it.second,
873                         "Entry-point has a duplicate binding for (group:%u, binding:%u).",
874                         resource.binding, resource.bind_group);
875 
876                     ShaderBindingInfo* info = &it.first->second;
877                     info->bindingType = TintResourceTypeToBindingInfoType(resource.resource_type);
878 
879                     switch (info->bindingType) {
880                         case BindingInfoType::Buffer:
881                             info->buffer.minBindingSize = resource.size_no_padding;
882                             DAWN_TRY_ASSIGN(info->buffer.type, TintResourceTypeToBufferBindingType(
883                                                                    resource.resource_type));
884                             break;
885                         case BindingInfoType::Sampler:
886                             switch (resource.resource_type) {
887                                 case tint::inspector::ResourceBinding::ResourceType::kSampler:
888                                     info->sampler.isComparison = false;
889                                     break;
890                                 case tint::inspector::ResourceBinding::ResourceType::
891                                     kComparisonSampler:
892                                     info->sampler.isComparison = true;
893                                     break;
894                                 default:
895                                     UNREACHABLE();
896                             }
897                             break;
898                         case BindingInfoType::Texture:
899                             info->texture.viewDimension =
900                                 TintTextureDimensionToTextureViewDimension(resource.dim);
901                             if (resource.resource_type ==
902                                     tint::inspector::ResourceBinding::ResourceType::kDepthTexture ||
903                                 resource.resource_type ==
904                                     tint::inspector::ResourceBinding::ResourceType::
905                                         kDepthMultisampledTexture) {
906                                 info->texture.compatibleSampleTypes = SampleTypeBit::Depth;
907                             } else {
908                                 info->texture.compatibleSampleTypes =
909                                     TintSampledKindToSampleTypeBit(resource.sampled_kind);
910                             }
911                             info->texture.multisampled =
912                                 resource.resource_type == tint::inspector::ResourceBinding::
913                                                               ResourceType::kMultisampledTexture ||
914                                 resource.resource_type ==
915                                     tint::inspector::ResourceBinding::ResourceType::
916                                         kDepthMultisampledTexture;
917 
918                             break;
919                         case BindingInfoType::StorageTexture:
920                             DAWN_TRY_ASSIGN(
921                                 info->storageTexture.access,
922                                 TintResourceTypeToStorageTextureAccess(resource.resource_type));
923                             info->storageTexture.format =
924                                 TintImageFormatToTextureFormat(resource.image_format);
925                             info->storageTexture.viewDimension =
926                                 TintTextureDimensionToTextureViewDimension(resource.dim);
927 
928                             break;
929                         case BindingInfoType::ExternalTexture:
930                             break;
931                         default:
932                             return DAWN_VALIDATION_ERROR("Unknown binding type in Shader");
933                     }
934                 }
935 
936                 std::vector<tint::inspector::SamplerTexturePair> samplerTextureUses =
937                     inspector.GetSamplerTextureUses(entryPoint.name);
938                 metadata->samplerTexturePairs.reserve(samplerTextureUses.size());
939                 std::transform(
940                     samplerTextureUses.begin(), samplerTextureUses.end(),
941                     std::back_inserter(metadata->samplerTexturePairs),
942                     [](const tint::inspector::SamplerTexturePair& pair) {
943                         EntryPointMetadata::SamplerTexturePair result;
944                         result.sampler = {BindGroupIndex(pair.sampler_binding_point.group),
945                                           BindingNumber(pair.sampler_binding_point.binding)};
946                         result.texture = {BindGroupIndex(pair.texture_binding_point.group),
947                                           BindingNumber(pair.texture_binding_point.binding)};
948                         return result;
949                     });
950 
951                 result[entryPoint.name] = std::move(metadata);
952             }
953             return std::move(result);
954         }
955     }  // anonymous namespace
956 
957     ShaderModuleParseResult::ShaderModuleParseResult() = default;
958     ShaderModuleParseResult::~ShaderModuleParseResult() = default;
959 
960     ShaderModuleParseResult::ShaderModuleParseResult(ShaderModuleParseResult&& rhs) = default;
961 
962     ShaderModuleParseResult& ShaderModuleParseResult::operator=(ShaderModuleParseResult&& rhs) =
963         default;
964 
HasParsedShader() const965     bool ShaderModuleParseResult::HasParsedShader() const {
966         return tintProgram != nullptr;
967     }
968 
969     // TintSource is a PIMPL container for a tint::Source::File, which needs to be kept alive for as
970     // long as tint diagnostics are inspected / printed.
971     class TintSource {
972       public:
973         template <typename... ARGS>
TintSource(ARGS &&...args)974         TintSource(ARGS&&... args) : file(std::forward<ARGS>(args)...) {
975         }
976 
977         tint::Source::File file;
978     };
979 
ValidateShaderModuleDescriptor(DeviceBase * device,const ShaderModuleDescriptor * descriptor,ShaderModuleParseResult * parseResult,OwnedCompilationMessages * outMessages)980     MaybeError ValidateShaderModuleDescriptor(DeviceBase* device,
981                                               const ShaderModuleDescriptor* descriptor,
982                                               ShaderModuleParseResult* parseResult,
983                                               OwnedCompilationMessages* outMessages) {
984         ASSERT(parseResult != nullptr);
985 
986         const ChainedStruct* chainedDescriptor = descriptor->nextInChain;
987         DAWN_INVALID_IF(chainedDescriptor == nullptr,
988                         "Shader module descriptor missing chained descriptor");
989 
990         // For now only a single SPIRV or WGSL subdescriptor is allowed.
991         DAWN_TRY(ValidateSingleSType(chainedDescriptor, wgpu::SType::ShaderModuleSPIRVDescriptor,
992                                      wgpu::SType::ShaderModuleWGSLDescriptor));
993 
994         ScopedTintICEHandler scopedICEHandler(device);
995 
996         const ShaderModuleSPIRVDescriptor* spirvDesc = nullptr;
997         FindInChain(chainedDescriptor, &spirvDesc);
998         const ShaderModuleWGSLDescriptor* wgslDesc = nullptr;
999         FindInChain(chainedDescriptor, &wgslDesc);
1000 
1001         // We have a temporary toggle to force the SPIRV ingestion to go through a WGSL
1002         // intermediate step. It is done by switching the spirvDesc for a wgslDesc below.
1003         ShaderModuleWGSLDescriptor newWgslDesc;
1004         std::string newWgslCode;
1005         if (spirvDesc && device->IsToggleEnabled(Toggle::ForceWGSLStep)) {
1006             std::vector<uint32_t> spirv(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize);
1007             tint::Program program;
1008             DAWN_TRY_ASSIGN(program, ParseSPIRV(spirv, outMessages));
1009 
1010             tint::writer::wgsl::Options options;
1011             auto result = tint::writer::wgsl::Generate(&program, options);
1012             DAWN_INVALID_IF(!result.success, "Tint WGSL failure: Generator: %s", result.error);
1013 
1014             newWgslCode = std::move(result.wgsl);
1015             newWgslDesc.source = newWgslCode.c_str();
1016 
1017             spirvDesc = nullptr;
1018             wgslDesc = &newWgslDesc;
1019         }
1020 
1021         if (spirvDesc) {
1022             DAWN_INVALID_IF(device->IsToggleEnabled(Toggle::DisallowSpirv),
1023                             "SPIR-V is disallowed.");
1024 
1025             std::vector<uint32_t> spirv(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize);
1026             tint::Program program;
1027             DAWN_TRY_ASSIGN(program, ParseSPIRV(spirv, outMessages));
1028             parseResult->tintProgram = std::make_unique<tint::Program>(std::move(program));
1029         } else if (wgslDesc) {
1030             auto tintSource = std::make_unique<TintSource>("", wgslDesc->source);
1031 
1032             if (device->IsToggleEnabled(Toggle::DumpShaders)) {
1033                 std::ostringstream dumpedMsg;
1034                 dumpedMsg << "// Dumped WGSL:" << std::endl << wgslDesc->source;
1035                 device->EmitLog(WGPULoggingType_Info, dumpedMsg.str().c_str());
1036             }
1037 
1038             tint::Program program;
1039             DAWN_TRY_ASSIGN(program, ParseWGSL(&tintSource->file, outMessages));
1040             parseResult->tintProgram = std::make_unique<tint::Program>(std::move(program));
1041             parseResult->tintSource = std::move(tintSource);
1042         }
1043 
1044         return {};
1045     }
1046 
ComputeRequiredBufferSizesForLayout(const EntryPointMetadata & entryPoint,const PipelineLayoutBase * layout)1047     RequiredBufferSizes ComputeRequiredBufferSizesForLayout(const EntryPointMetadata& entryPoint,
1048                                                             const PipelineLayoutBase* layout) {
1049         RequiredBufferSizes bufferSizes;
1050         for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
1051             bufferSizes[group] = GetBindGroupMinBufferSizes(entryPoint.bindings[group],
1052                                                             layout->GetBindGroupLayout(group));
1053         }
1054 
1055         return bufferSizes;
1056     }
1057 
RunTransforms(tint::transform::Transform * transform,const tint::Program * program,const tint::transform::DataMap & inputs,tint::transform::DataMap * outputs,OwnedCompilationMessages * outMessages)1058     ResultOrError<tint::Program> RunTransforms(tint::transform::Transform* transform,
1059                                                const tint::Program* program,
1060                                                const tint::transform::DataMap& inputs,
1061                                                tint::transform::DataMap* outputs,
1062                                                OwnedCompilationMessages* outMessages) {
1063         tint::transform::Output output = transform->Run(program, inputs);
1064         if (outMessages != nullptr) {
1065             outMessages->AddMessages(output.program.Diagnostics());
1066         }
1067         DAWN_INVALID_IF(!output.program.IsValid(), "Tint program failure: %s\n",
1068                         output.program.Diagnostics().str());
1069         if (outputs != nullptr) {
1070             *outputs = std::move(output.data);
1071         }
1072         return std::move(output.program);
1073     }
1074 
AddVertexPullingTransformConfig(const RenderPipelineBase & renderPipeline,const std::string & entryPoint,BindGroupIndex pullingBufferBindingSet,tint::transform::DataMap * transformInputs)1075     void AddVertexPullingTransformConfig(const RenderPipelineBase& renderPipeline,
1076                                          const std::string& entryPoint,
1077                                          BindGroupIndex pullingBufferBindingSet,
1078                                          tint::transform::DataMap* transformInputs) {
1079         tint::transform::VertexPulling::Config cfg;
1080         cfg.entry_point_name = entryPoint;
1081         cfg.pulling_group = static_cast<uint32_t>(pullingBufferBindingSet);
1082 
1083         cfg.vertex_state.resize(renderPipeline.GetVertexBufferCount());
1084         for (VertexBufferSlot slot : IterateBitSet(renderPipeline.GetVertexBufferSlotsUsed())) {
1085             const VertexBufferInfo& dawnInfo = renderPipeline.GetVertexBuffer(slot);
1086             tint::transform::VertexBufferLayoutDescriptor* tintInfo =
1087                 &cfg.vertex_state[static_cast<uint8_t>(slot)];
1088 
1089             tintInfo->array_stride = dawnInfo.arrayStride;
1090             tintInfo->step_mode = ToTintVertexStepMode(dawnInfo.stepMode);
1091         }
1092 
1093         for (VertexAttributeLocation location :
1094              IterateBitSet(renderPipeline.GetAttributeLocationsUsed())) {
1095             const VertexAttributeInfo& dawnInfo = renderPipeline.GetAttribute(location);
1096             tint::transform::VertexAttributeDescriptor tintInfo;
1097             tintInfo.format = ToTintVertexFormat(dawnInfo.format);
1098             tintInfo.offset = dawnInfo.offset;
1099             tintInfo.shader_location = static_cast<uint32_t>(static_cast<uint8_t>(location));
1100 
1101             uint8_t vertexBufferSlot = static_cast<uint8_t>(dawnInfo.vertexBufferSlot);
1102             cfg.vertex_state[vertexBufferSlot].attributes.push_back(tintInfo);
1103         }
1104 
1105         transformInputs->Add<tint::transform::VertexPulling::Config>(cfg);
1106     }
1107 
ValidateCompatibilityWithPipelineLayout(DeviceBase * device,const EntryPointMetadata & entryPoint,const PipelineLayoutBase * layout)1108     MaybeError ValidateCompatibilityWithPipelineLayout(DeviceBase* device,
1109                                                        const EntryPointMetadata& entryPoint,
1110                                                        const PipelineLayoutBase* layout) {
1111         for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
1112             DAWN_TRY_CONTEXT(ValidateCompatibilityWithBindGroupLayout(
1113                                  device, group, entryPoint, layout->GetBindGroupLayout(group)),
1114                              "validating the entry-point's compatibility for group %u with %s",
1115                              static_cast<uint32_t>(group), layout->GetBindGroupLayout(group));
1116         }
1117 
1118         for (BindGroupIndex group : IterateBitSet(~layout->GetBindGroupLayoutsMask())) {
1119             DAWN_INVALID_IF(entryPoint.bindings[group].size() > 0,
1120                             "The entry-point uses bindings in group %u but %s doesn't have a "
1121                             "BindGroupLayout for this index",
1122                             static_cast<uint32_t>(group), layout);
1123         }
1124 
1125         // Validate that filtering samplers are not used with unfilterable textures.
1126         for (const auto& pair : entryPoint.samplerTexturePairs) {
1127             const BindGroupLayoutBase* samplerBGL = layout->GetBindGroupLayout(pair.sampler.group);
1128             const BindingInfo& samplerInfo =
1129                 samplerBGL->GetBindingInfo(samplerBGL->GetBindingIndex(pair.sampler.binding));
1130             if (samplerInfo.sampler.type != wgpu::SamplerBindingType::Filtering) {
1131                 continue;
1132             }
1133             const BindGroupLayoutBase* textureBGL = layout->GetBindGroupLayout(pair.texture.group);
1134             const BindingInfo& textureInfo =
1135                 textureBGL->GetBindingInfo(textureBGL->GetBindingIndex(pair.texture.binding));
1136 
1137             ASSERT(textureInfo.bindingType != BindingInfoType::Buffer &&
1138                    textureInfo.bindingType != BindingInfoType::Sampler &&
1139                    textureInfo.bindingType != BindingInfoType::StorageTexture);
1140 
1141             if (textureInfo.bindingType != BindingInfoType::Texture) {
1142                 continue;
1143             }
1144 
1145             // Uint/sint can't be statically used with a sampler, so they any
1146             // texture bindings reflected must be float or depth textures. If
1147             // the shader uses a float/depth texture but the bind group layout
1148             // specifies a uint/sint texture binding,
1149             // |ValidateCompatibilityWithBindGroupLayout| will fail since the
1150             // sampleType does not match.
1151             ASSERT(textureInfo.texture.sampleType != wgpu::TextureSampleType::Undefined &&
1152                    textureInfo.texture.sampleType != wgpu::TextureSampleType::Uint &&
1153                    textureInfo.texture.sampleType != wgpu::TextureSampleType::Sint);
1154 
1155             DAWN_INVALID_IF(
1156                 textureInfo.texture.sampleType == wgpu::TextureSampleType::UnfilterableFloat,
1157                 "Texture binding (group:%u, binding:%u) is %s but used statically with a sampler "
1158                 "(group:%u, binding:%u) that's %s",
1159                 static_cast<uint32_t>(pair.texture.group),
1160                 static_cast<uint32_t>(pair.texture.binding),
1161                 wgpu::TextureSampleType::UnfilterableFloat,
1162                 static_cast<uint32_t>(pair.sampler.group),
1163                 static_cast<uint32_t>(pair.sampler.binding), wgpu::SamplerBindingType::Filtering);
1164         }
1165 
1166         return {};
1167     }
1168 
1169     // ShaderModuleBase
1170 
ShaderModuleBase(DeviceBase * device,const ShaderModuleDescriptor * descriptor,ApiObjectBase::UntrackedByDeviceTag tag)1171     ShaderModuleBase::ShaderModuleBase(DeviceBase* device,
1172                                        const ShaderModuleDescriptor* descriptor,
1173                                        ApiObjectBase::UntrackedByDeviceTag tag)
1174         : ApiObjectBase(device, descriptor->label), mType(Type::Undefined) {
1175         ASSERT(descriptor->nextInChain != nullptr);
1176         const ShaderModuleSPIRVDescriptor* spirvDesc = nullptr;
1177         FindInChain(descriptor->nextInChain, &spirvDesc);
1178         const ShaderModuleWGSLDescriptor* wgslDesc = nullptr;
1179         FindInChain(descriptor->nextInChain, &wgslDesc);
1180         ASSERT(spirvDesc || wgslDesc);
1181 
1182         if (spirvDesc) {
1183             mType = Type::Spirv;
1184             mOriginalSpirv.assign(spirvDesc->code, spirvDesc->code + spirvDesc->codeSize);
1185         } else if (wgslDesc) {
1186             mType = Type::Wgsl;
1187             mWgsl = std::string(wgslDesc->source);
1188         }
1189     }
1190 
ShaderModuleBase(DeviceBase * device,const ShaderModuleDescriptor * descriptor)1191     ShaderModuleBase::ShaderModuleBase(DeviceBase* device, const ShaderModuleDescriptor* descriptor)
1192         : ShaderModuleBase(device, descriptor, kUntrackedByDevice) {
1193         TrackInDevice();
1194     }
1195 
ShaderModuleBase(DeviceBase * device)1196     ShaderModuleBase::ShaderModuleBase(DeviceBase* device)
1197         : ApiObjectBase(device, kLabelNotImplemented) {
1198         TrackInDevice();
1199     }
1200 
ShaderModuleBase(DeviceBase * device,ObjectBase::ErrorTag tag)1201     ShaderModuleBase::ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag)
1202         : ApiObjectBase(device, tag), mType(Type::Undefined) {
1203     }
1204 
1205     ShaderModuleBase::~ShaderModuleBase() = default;
1206 
DestroyImpl()1207     void ShaderModuleBase::DestroyImpl() {
1208         if (IsCachedReference()) {
1209             // Do not uncache the actual cached object if we are a blueprint.
1210             GetDevice()->UncacheShaderModule(this);
1211         }
1212     }
1213 
1214     // static
MakeError(DeviceBase * device)1215     Ref<ShaderModuleBase> ShaderModuleBase::MakeError(DeviceBase* device) {
1216         return AcquireRef(new ShaderModuleBase(device, ObjectBase::kError));
1217     }
1218 
GetType() const1219     ObjectType ShaderModuleBase::GetType() const {
1220         return ObjectType::ShaderModule;
1221     }
1222 
HasEntryPoint(const std::string & entryPoint) const1223     bool ShaderModuleBase::HasEntryPoint(const std::string& entryPoint) const {
1224         return mEntryPoints.count(entryPoint) > 0;
1225     }
1226 
GetEntryPoint(const std::string & entryPoint) const1227     const EntryPointMetadata& ShaderModuleBase::GetEntryPoint(const std::string& entryPoint) const {
1228         ASSERT(HasEntryPoint(entryPoint));
1229         return *mEntryPoints.at(entryPoint);
1230     }
1231 
ComputeContentHash()1232     size_t ShaderModuleBase::ComputeContentHash() {
1233         ObjectContentHasher recorder;
1234         recorder.Record(mType);
1235         recorder.Record(mOriginalSpirv);
1236         recorder.Record(mWgsl);
1237         return recorder.GetContentHash();
1238     }
1239 
operator ()(const ShaderModuleBase * a,const ShaderModuleBase * b) const1240     bool ShaderModuleBase::EqualityFunc::operator()(const ShaderModuleBase* a,
1241                                                     const ShaderModuleBase* b) const {
1242         return a->mType == b->mType && a->mOriginalSpirv == b->mOriginalSpirv &&
1243                a->mWgsl == b->mWgsl;
1244     }
1245 
GetTintProgram() const1246     const tint::Program* ShaderModuleBase::GetTintProgram() const {
1247         ASSERT(mTintProgram);
1248         return mTintProgram.get();
1249     }
1250 
APIGetCompilationInfo(wgpu::CompilationInfoCallback callback,void * userdata)1251     void ShaderModuleBase::APIGetCompilationInfo(wgpu::CompilationInfoCallback callback,
1252                                                  void* userdata) {
1253         if (callback == nullptr) {
1254             return;
1255         }
1256 
1257         callback(WGPUCompilationInfoRequestStatus_Success,
1258                  mCompilationMessages->GetCompilationInfo(), userdata);
1259     }
1260 
InjectCompilationMessages(std::unique_ptr<OwnedCompilationMessages> compilationMessages)1261     void ShaderModuleBase::InjectCompilationMessages(
1262         std::unique_ptr<OwnedCompilationMessages> compilationMessages) {
1263         // TODO(dawn:944): ensure the InjectCompilationMessages is properly handled for shader
1264         // module returned from cache.
1265         // InjectCompilationMessages should be called only once for a shader module, after it is
1266         // created. However currently InjectCompilationMessages may be called on a shader module
1267         // returned from cache rather than newly created, and violate the rule. We just skip the
1268         // injection in this case for now, but a proper solution including ensure the cache goes
1269         // before the validation is required.
1270         if (mCompilationMessages != nullptr) {
1271             return;
1272         }
1273         // Move the compilationMessages into the shader module and emit the tint errors and warnings
1274         mCompilationMessages = std::move(compilationMessages);
1275 
1276         // Emit the formatted Tint errors and warnings within the moved compilationMessages
1277         const std::vector<std::string>& formattedTintMessages =
1278             mCompilationMessages->GetFormattedTintMessages();
1279         if (formattedTintMessages.empty()) {
1280             return;
1281         }
1282         std::ostringstream t;
1283         for (auto pMessage = formattedTintMessages.begin(); pMessage != formattedTintMessages.end();
1284              pMessage++) {
1285             if (pMessage != formattedTintMessages.begin()) {
1286                 t << std::endl;
1287             }
1288             t << *pMessage;
1289         }
1290         this->GetDevice()->EmitLog(WGPULoggingType_Warning, t.str().c_str());
1291     }
1292 
GetCompilationMessages() const1293     OwnedCompilationMessages* ShaderModuleBase::GetCompilationMessages() const {
1294         return mCompilationMessages.get();
1295     }
1296 
InitializeBase(ShaderModuleParseResult * parseResult)1297     MaybeError ShaderModuleBase::InitializeBase(ShaderModuleParseResult* parseResult) {
1298         mTintProgram = std::move(parseResult->tintProgram);
1299         mTintSource = std::move(parseResult->tintSource);
1300 
1301         DAWN_TRY_ASSIGN(mEntryPoints, ReflectShaderUsingTint(GetDevice(), mTintProgram.get()));
1302         return {};
1303     }
1304 
operator ()(const PipelineLayoutEntryPointPair & pair) const1305     size_t PipelineLayoutEntryPointPairHashFunc::operator()(
1306         const PipelineLayoutEntryPointPair& pair) const {
1307         size_t hash = 0;
1308         HashCombine(&hash, pair.first, pair.second);
1309         return hash;
1310     }
1311 
1312 }  // namespace dawn_native
1313