• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2017 The Dawn Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "dawn_native/d3d12/ShaderModuleD3D12.h"
16 
17 #include "common/Assert.h"
18 #include "common/BitSetIterator.h"
19 #include "common/Log.h"
20 #include "common/WindowsUtils.h"
21 #include "dawn_native/Pipeline.h"
22 #include "dawn_native/TintUtils.h"
23 #include "dawn_native/d3d12/BindGroupLayoutD3D12.h"
24 #include "dawn_native/d3d12/D3D12Error.h"
25 #include "dawn_native/d3d12/DeviceD3D12.h"
26 #include "dawn_native/d3d12/PipelineLayoutD3D12.h"
27 #include "dawn_native/d3d12/PlatformFunctions.h"
28 #include "dawn_native/d3d12/UtilsD3D12.h"
29 
30 #include <d3dcompiler.h>
31 
32 #include <tint/tint.h>
33 #include <map>
34 #include <sstream>
35 #include <unordered_map>
36 
37 namespace dawn_native { namespace d3d12 {
38 
39     namespace {
GetDXCompilerVersion(ComPtr<IDxcValidator> dxcValidator)40         ResultOrError<uint64_t> GetDXCompilerVersion(ComPtr<IDxcValidator> dxcValidator) {
41             ComPtr<IDxcVersionInfo> versionInfo;
42             DAWN_TRY(CheckHRESULT(dxcValidator.As(&versionInfo),
43                                   "D3D12 QueryInterface IDxcValidator to IDxcVersionInfo"));
44 
45             uint32_t compilerMajor, compilerMinor;
46             DAWN_TRY(CheckHRESULT(versionInfo->GetVersion(&compilerMajor, &compilerMinor),
47                                   "IDxcVersionInfo::GetVersion"));
48 
49             // Pack both into a single version number.
50             return (uint64_t(compilerMajor) << uint64_t(32)) + compilerMinor;
51         }
52 
GetD3DCompilerVersion()53         uint64_t GetD3DCompilerVersion() {
54             return D3D_COMPILER_VERSION;
55         }
56 
57         struct CompareBindingPoint {
operator ()dawn_native::d3d12::__anone61f89010111::CompareBindingPoint58             constexpr bool operator()(const tint::transform::BindingPoint& lhs,
59                                       const tint::transform::BindingPoint& rhs) const {
60                 if (lhs.group != rhs.group) {
61                     return lhs.group < rhs.group;
62                 } else {
63                     return lhs.binding < rhs.binding;
64                 }
65             }
66         };
67 
Serialize(std::stringstream & output,const tint::ast::Access & access)68         void Serialize(std::stringstream& output, const tint::ast::Access& access) {
69             output << access;
70         }
71 
Serialize(std::stringstream & output,const tint::transform::BindingPoint & binding_point)72         void Serialize(std::stringstream& output,
73                        const tint::transform::BindingPoint& binding_point) {
74             output << "(BindingPoint";
75             output << " group=" << binding_point.group;
76             output << " binding=" << binding_point.binding;
77             output << ")";
78         }
79 
// Fallback serializer for fundamental types (integers, floating point, bool, char...):
// the value is streamed as-is. The defaulted template parameter removes this overload
// from consideration for any non-fundamental T.
template <typename T, typename = std::enable_if_t<std::is_fundamental<T>::value>>
void Serialize(std::stringstream& output, const T& val) {
    output << val;
}
85 
86         template <typename T>
Serialize(std::stringstream & output,const std::unordered_map<tint::transform::BindingPoint,T> & map)87         void Serialize(std::stringstream& output,
88                        const std::unordered_map<tint::transform::BindingPoint, T>& map) {
89             output << "(map";
90 
91             std::map<tint::transform::BindingPoint, T, CompareBindingPoint> sorted(map.begin(),
92                                                                                    map.end());
93             for (auto& entry : sorted) {
94                 output << " ";
95                 Serialize(output, entry.first);
96                 output << "=";
97                 Serialize(output, entry.second);
98             }
99             output << ")";
100         }
101 
Serialize(std::stringstream & output,const tint::writer::ArrayLengthFromUniformOptions & arrayLengthFromUniform)102         void Serialize(std::stringstream& output,
103                        const tint::writer::ArrayLengthFromUniformOptions& arrayLengthFromUniform) {
104             output << "(ArrayLengthFromUniformOptions";
105             output << " ubo_binding=";
106             Serialize(output, arrayLengthFromUniform.ubo_binding);
107             output << " bindpoint_to_size_index=";
108             Serialize(output, arrayLengthFromUniform.bindpoint_to_size_index);
109             output << ")";
110         }
111 
// Formats `v` in scientific notation with `n` digits after the decimal point, i.e.
// n + 1 significant digits. A 32-bit float needs at most 9 significant decimal digits
// to round-trip exactly, so the default n = 8 is lossless for every finite value.
//
// Scientific notation is used instead of fixed notation on purpose: with std::fixed the
// precision counts digits after the decimal point, so small magnitudes (for example
// 1e-10) would be written as "0.00000000" and silently turn into zero.
std::string FloatToStringWithPrecision(float v, std::streamsize n = 8) {
    std::ostringstream out;
    out.precision(n);
    out << std::scientific << v;
    return out.str();
}
119 
GetHLSLValueString(EntryPointMetadata::OverridableConstant::Type dawnType,const OverridableConstantScalar * entry,double value=0)120         std::string GetHLSLValueString(EntryPointMetadata::OverridableConstant::Type dawnType,
121                                        const OverridableConstantScalar* entry,
122                                        double value = 0) {
123             switch (dawnType) {
124                 case EntryPointMetadata::OverridableConstant::Type::Boolean:
125                     return std::to_string(entry ? entry->b : static_cast<int32_t>(value));
126                 case EntryPointMetadata::OverridableConstant::Type::Float32:
127                     return FloatToStringWithPrecision(entry ? entry->f32
128                                                             : static_cast<float>(value));
129                 case EntryPointMetadata::OverridableConstant::Type::Int32:
130                     return std::to_string(entry ? entry->i32 : static_cast<int32_t>(value));
131                 case EntryPointMetadata::OverridableConstant::Type::Uint32:
132                     return std::to_string(entry ? entry->u32 : static_cast<uint32_t>(value));
133                 default:
134                     UNREACHABLE();
135             }
136         }
137 
138         constexpr char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_";
139 
GetOverridableConstantsDefines(std::vector<std::pair<std::string,std::string>> * defineStrings,const PipelineConstantEntries * pipelineConstantEntries,const EntryPointMetadata::OverridableConstantsMap * shaderEntryPointConstants)140         void GetOverridableConstantsDefines(
141             std::vector<std::pair<std::string, std::string>>* defineStrings,
142             const PipelineConstantEntries* pipelineConstantEntries,
143             const EntryPointMetadata::OverridableConstantsMap* shaderEntryPointConstants) {
144             std::unordered_set<std::string> overriddenConstants;
145 
146             // Set pipeline overridden values
147             for (const auto& pipelineConstant : *pipelineConstantEntries) {
148                 const std::string& name = pipelineConstant.first;
149                 double value = pipelineConstant.second;
150 
151                 overriddenConstants.insert(name);
152 
153                 // This is already validated so `name` must exist
154                 const auto& moduleConstant = shaderEntryPointConstants->at(name);
155 
156                 defineStrings->emplace_back(
157                     kSpecConstantPrefix + std::to_string(static_cast<int32_t>(moduleConstant.id)),
158                     GetHLSLValueString(moduleConstant.type, nullptr, value));
159             }
160 
161             // Set shader initialized default values
162             for (const auto& iter : *shaderEntryPointConstants) {
163                 const std::string& name = iter.first;
164                 if (overriddenConstants.count(name) != 0) {
165                     // This constant already has overridden value
166                     continue;
167                 }
168 
169                 const auto& moduleConstant = shaderEntryPointConstants->at(name);
170 
171                 // Uninitialized default values are okay since they ar only defined to pass
172                 // compilation but not used
173                 defineStrings->emplace_back(
174                     kSpecConstantPrefix + std::to_string(static_cast<int32_t>(moduleConstant.id)),
175                     GetHLSLValueString(moduleConstant.type, &moduleConstant.defaultValue));
176             }
177         }
178 
179         // The inputs to a shader compilation. These have been intentionally isolated from the
180         // device to help ensure that the pipeline cache key contains all inputs for compilation.
181         struct ShaderCompilationRequest {
182             enum Compiler { FXC, DXC };
183 
184             // Common inputs
185             Compiler compiler;
186             const tint::Program* program;
187             const char* entryPointName;
188             SingleShaderStage stage;
189             uint32_t compileFlags;
190             bool disableSymbolRenaming;
191             tint::transform::BindingRemapper::BindingPoints remappedBindingPoints;
192             tint::transform::BindingRemapper::AccessControls remappedAccessControls;
193             bool isRobustnessEnabled;
194             bool usesNumWorkgroups;
195             uint32_t numWorkgroupsRegisterSpace;
196             uint32_t numWorkgroupsShaderRegister;
197             tint::writer::ArrayLengthFromUniformOptions arrayLengthFromUniform;
198             std::vector<std::pair<std::string, std::string>> defineStrings;
199 
200             // FXC/DXC common inputs
201             bool disableWorkgroupInit;
202 
203             // FXC inputs
204             uint64_t fxcVersion;
205 
206             // DXC inputs
207             uint64_t dxcVersion;
208             const D3D12DeviceInfo* deviceInfo;
209             bool hasShaderFloat16Feature;
210 
Createdawn_native::d3d12::__anone61f89010111::ShaderCompilationRequest211             static ResultOrError<ShaderCompilationRequest> Create(
212                 const char* entryPointName,
213                 SingleShaderStage stage,
214                 const PipelineLayout* layout,
215                 uint32_t compileFlags,
216                 const Device* device,
217                 const tint::Program* program,
218                 const EntryPointMetadata& entryPoint,
219                 const ProgrammableStage& programmableStage) {
220                 Compiler compiler;
221                 uint64_t dxcVersion = 0;
222                 if (device->IsToggleEnabled(Toggle::UseDXC)) {
223                     compiler = Compiler::DXC;
224                     DAWN_TRY_ASSIGN(dxcVersion, GetDXCompilerVersion(device->GetDxcValidator()));
225                 } else {
226                     compiler = Compiler::FXC;
227                 }
228 
229                 using tint::transform::BindingPoint;
230                 using tint::transform::BindingRemapper;
231 
232                 BindingRemapper::BindingPoints remappedBindingPoints;
233                 BindingRemapper::AccessControls remappedAccessControls;
234 
235                 tint::writer::ArrayLengthFromUniformOptions arrayLengthFromUniform;
236                 arrayLengthFromUniform.ubo_binding = {
237                     layout->GetDynamicStorageBufferLengthsRegisterSpace(),
238                     layout->GetDynamicStorageBufferLengthsShaderRegister()};
239 
240                 const BindingInfoArray& moduleBindingInfo = entryPoint.bindings;
241                 for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
242                     const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
243                     const auto& groupBindingInfo = moduleBindingInfo[group];
244 
245                     // d3d12::BindGroupLayout packs the bindings per HLSL register-space. We modify
246                     // the Tint AST to make the "bindings" decoration match the offset chosen by
247                     // d3d12::BindGroupLayout so that Tint produces HLSL with the correct registers
248                     // assigned to each interface variable.
249                     for (const auto& it : groupBindingInfo) {
250                         BindingNumber binding = it.first;
251                         auto const& bindingInfo = it.second;
252                         BindingIndex bindingIndex = bgl->GetBindingIndex(binding);
253                         BindingPoint srcBindingPoint{static_cast<uint32_t>(group),
254                                                      static_cast<uint32_t>(binding)};
255                         BindingPoint dstBindingPoint{static_cast<uint32_t>(group),
256                                                      bgl->GetShaderRegister(bindingIndex)};
257                         if (srcBindingPoint != dstBindingPoint) {
258                             remappedBindingPoints.emplace(srcBindingPoint, dstBindingPoint);
259                         }
260 
261                         // Declaring a read-only storage buffer in HLSL but specifying a storage
262                         // buffer in the BGL produces the wrong output. Force read-only storage
263                         // buffer bindings to be treated as UAV instead of SRV. Internal storage
264                         // buffer is a storage buffer used in the internal pipeline.
265                         const bool forceStorageBufferAsUAV =
266                             (bindingInfo.buffer.type == wgpu::BufferBindingType::ReadOnlyStorage &&
267                              (bgl->GetBindingInfo(bindingIndex).buffer.type ==
268                                   wgpu::BufferBindingType::Storage ||
269                               bgl->GetBindingInfo(bindingIndex).buffer.type ==
270                                   kInternalStorageBufferBinding));
271                         if (forceStorageBufferAsUAV) {
272                             remappedAccessControls.emplace(srcBindingPoint,
273                                                            tint::ast::Access::kReadWrite);
274                         }
275                     }
276 
277                     // Add arrayLengthFromUniform options
278                     {
279                         for (const auto& bindingAndRegisterOffset :
280                              layout->GetDynamicStorageBufferLengthInfo()[group]
281                                  .bindingAndRegisterOffsets) {
282                             BindingNumber binding = bindingAndRegisterOffset.binding;
283                             uint32_t registerOffset = bindingAndRegisterOffset.registerOffset;
284 
285                             BindingPoint bindingPoint{static_cast<uint32_t>(group),
286                                                       static_cast<uint32_t>(binding)};
287                             // Get the renamed binding point if it was remapped.
288                             auto it = remappedBindingPoints.find(bindingPoint);
289                             if (it != remappedBindingPoints.end()) {
290                                 bindingPoint = it->second;
291                             }
292 
293                             arrayLengthFromUniform.bindpoint_to_size_index.emplace(bindingPoint,
294                                                                                    registerOffset);
295                         }
296                     }
297                 }
298 
299                 ShaderCompilationRequest request;
300                 request.compiler = compiler;
301                 request.program = program;
302                 request.entryPointName = entryPointName;
303                 request.stage = stage;
304                 request.compileFlags = compileFlags;
305                 request.disableSymbolRenaming =
306                     device->IsToggleEnabled(Toggle::DisableSymbolRenaming);
307                 request.remappedBindingPoints = std::move(remappedBindingPoints);
308                 request.remappedAccessControls = std::move(remappedAccessControls);
309                 request.isRobustnessEnabled = device->IsRobustnessEnabled();
310                 request.disableWorkgroupInit =
311                     device->IsToggleEnabled(Toggle::DisableWorkgroupInit);
312                 request.usesNumWorkgroups = entryPoint.usesNumWorkgroups;
313                 request.numWorkgroupsShaderRegister = layout->GetNumWorkgroupsShaderRegister();
314                 request.numWorkgroupsRegisterSpace = layout->GetNumWorkgroupsRegisterSpace();
315                 request.arrayLengthFromUniform = std::move(arrayLengthFromUniform);
316                 request.fxcVersion = compiler == Compiler::FXC ? GetD3DCompilerVersion() : 0;
317                 request.dxcVersion = compiler == Compiler::DXC ? dxcVersion : 0;
318                 request.deviceInfo = &device->GetDeviceInfo();
319                 request.hasShaderFloat16Feature = device->IsFeatureEnabled(Feature::ShaderFloat16);
320 
321                 GetOverridableConstantsDefines(
322                     &request.defineStrings, &programmableStage.constants,
323                     &programmableStage.module->GetEntryPoint(programmableStage.entryPoint)
324                          .overridableConstants);
325 
326                 return std::move(request);
327             }
328 
CreateCacheKeydawn_native::d3d12::__anone61f89010111::ShaderCompilationRequest329             ResultOrError<PersistentCacheKey> CreateCacheKey() const {
330                 // Generate the WGSL from the Tint program so it's normalized.
331                 // TODO(tint:1180): Consider using a binary serialization of the tint AST for a more
332                 // compact representation.
333                 auto result = tint::writer::wgsl::Generate(program, tint::writer::wgsl::Options{});
334                 if (!result.success) {
335                     std::ostringstream errorStream;
336                     errorStream << "Tint WGSL failure:" << std::endl;
337                     errorStream << "Generator: " << result.error << std::endl;
338                     return DAWN_INTERNAL_ERROR(errorStream.str().c_str());
339                 }
340 
341                 std::stringstream stream;
342 
343                 // Prefix the key with the type to avoid collisions from another type that could
344                 // have the same key.
345                 stream << static_cast<uint32_t>(PersistentKeyType::Shader);
346                 stream << "\n";
347 
348                 stream << result.wgsl.length();
349                 stream << "\n";
350 
351                 stream << result.wgsl;
352                 stream << "\n";
353 
354                 stream << "(ShaderCompilationRequest";
355                 stream << " compiler=" << compiler;
356                 stream << " entryPointName=" << entryPointName;
357                 stream << " stage=" << uint32_t(stage);
358                 stream << " compileFlags=" << compileFlags;
359                 stream << " disableSymbolRenaming=" << disableSymbolRenaming;
360 
361                 stream << " remappedBindingPoints=";
362                 Serialize(stream, remappedBindingPoints);
363 
364                 stream << " remappedAccessControls=";
365                 Serialize(stream, remappedAccessControls);
366 
367                 stream << " useNumWorkgroups=" << usesNumWorkgroups;
368                 stream << " numWorkgroupsRegisterSpace=" << numWorkgroupsRegisterSpace;
369                 stream << " numWorkgroupsShaderRegister=" << numWorkgroupsShaderRegister;
370 
371                 stream << " arrayLengthFromUniform=";
372                 Serialize(stream, arrayLengthFromUniform);
373 
374                 stream << " shaderModel=" << deviceInfo->shaderModel;
375                 stream << " disableWorkgroupInit=" << disableWorkgroupInit;
376                 stream << " isRobustnessEnabled=" << isRobustnessEnabled;
377                 stream << " fxcVersion=" << fxcVersion;
378                 stream << " dxcVersion=" << dxcVersion;
379                 stream << " hasShaderFloat16Feature=" << hasShaderFloat16Feature;
380 
381                 stream << " defines={";
382                 for (const auto& it : defineStrings) {
383                     stream << " <" << it.first << "," << it.second << ">";
384                 }
385                 stream << " }";
386 
387                 stream << ")";
388                 stream << "\n";
389 
390                 return PersistentCacheKey(std::istreambuf_iterator<char>{stream},
391                                           std::istreambuf_iterator<char>{});
392             }
393         };
394 
GetDXCArguments(uint32_t compileFlags,bool enable16BitTypes)395         std::vector<const wchar_t*> GetDXCArguments(uint32_t compileFlags, bool enable16BitTypes) {
396             std::vector<const wchar_t*> arguments;
397             if (compileFlags & D3DCOMPILE_ENABLE_BACKWARDS_COMPATIBILITY) {
398                 arguments.push_back(L"/Gec");
399             }
400             if (compileFlags & D3DCOMPILE_IEEE_STRICTNESS) {
401                 arguments.push_back(L"/Gis");
402             }
403             constexpr uint32_t d3dCompileFlagsBits = D3DCOMPILE_OPTIMIZATION_LEVEL2;
404             if (compileFlags & d3dCompileFlagsBits) {
405                 switch (compileFlags & D3DCOMPILE_OPTIMIZATION_LEVEL2) {
406                     case D3DCOMPILE_OPTIMIZATION_LEVEL0:
407                         arguments.push_back(L"/O0");
408                         break;
409                     case D3DCOMPILE_OPTIMIZATION_LEVEL2:
410                         arguments.push_back(L"/O2");
411                         break;
412                     case D3DCOMPILE_OPTIMIZATION_LEVEL3:
413                         arguments.push_back(L"/O3");
414                         break;
415                 }
416             }
417             if (compileFlags & D3DCOMPILE_DEBUG) {
418                 arguments.push_back(L"/Zi");
419             }
420             if (compileFlags & D3DCOMPILE_PACK_MATRIX_ROW_MAJOR) {
421                 arguments.push_back(L"/Zpr");
422             }
423             if (compileFlags & D3DCOMPILE_PACK_MATRIX_COLUMN_MAJOR) {
424                 arguments.push_back(L"/Zpc");
425             }
426             if (compileFlags & D3DCOMPILE_AVOID_FLOW_CONTROL) {
427                 arguments.push_back(L"/Gfa");
428             }
429             if (compileFlags & D3DCOMPILE_PREFER_FLOW_CONTROL) {
430                 arguments.push_back(L"/Gfp");
431             }
432             if (compileFlags & D3DCOMPILE_RESOURCES_MAY_ALIAS) {
433                 arguments.push_back(L"/res_may_alias");
434             }
435 
436             if (enable16BitTypes) {
437                 // enable-16bit-types are only allowed in -HV 2018 (default)
438                 arguments.push_back(L"/enable-16bit-types");
439             }
440 
441             arguments.push_back(L"-HV");
442             arguments.push_back(L"2018");
443 
444             return arguments;
445         }
446 
CompileShaderDXC(IDxcLibrary * dxcLibrary,IDxcCompiler * dxcCompiler,const ShaderCompilationRequest & request,const std::string & hlslSource)447         ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(IDxcLibrary* dxcLibrary,
448                                                          IDxcCompiler* dxcCompiler,
449                                                          const ShaderCompilationRequest& request,
450                                                          const std::string& hlslSource) {
451             ComPtr<IDxcBlobEncoding> sourceBlob;
452             DAWN_TRY(
453                 CheckHRESULT(dxcLibrary->CreateBlobWithEncodingOnHeapCopy(
454                                  hlslSource.c_str(), hlslSource.length(), CP_UTF8, &sourceBlob),
455                              "DXC create blob"));
456 
457             std::wstring entryPointW;
458             DAWN_TRY_ASSIGN(entryPointW, ConvertStringToWstring(request.entryPointName));
459 
460             std::vector<const wchar_t*> arguments =
461                 GetDXCArguments(request.compileFlags, request.hasShaderFloat16Feature);
462 
463             // Build defines for overridable constants
464             std::vector<std::pair<std::wstring, std::wstring>> defineStrings;
465             defineStrings.reserve(request.defineStrings.size());
466             for (const auto& it : request.defineStrings) {
467                 defineStrings.emplace_back(UTF8ToWStr(it.first.c_str()),
468                                            UTF8ToWStr(it.second.c_str()));
469             }
470 
471             std::vector<DxcDefine> dxcDefines;
472             dxcDefines.reserve(defineStrings.size());
473             for (const auto& d : defineStrings) {
474                 dxcDefines.push_back({d.first.c_str(), d.second.c_str()});
475             }
476 
477             ComPtr<IDxcOperationResult> result;
478             DAWN_TRY(CheckHRESULT(
479                 dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(),
480                                      request.deviceInfo->shaderProfiles[request.stage].c_str(),
481                                      arguments.data(), arguments.size(), dxcDefines.data(),
482                                      dxcDefines.size(), nullptr, &result),
483                 "DXC compile"));
484 
485             HRESULT hr;
486             DAWN_TRY(CheckHRESULT(result->GetStatus(&hr), "DXC get status"));
487 
488             if (FAILED(hr)) {
489                 ComPtr<IDxcBlobEncoding> errors;
490                 DAWN_TRY(CheckHRESULT(result->GetErrorBuffer(&errors), "DXC get error buffer"));
491 
492                 return DAWN_FORMAT_VALIDATION_ERROR("DXC compile failed with: %s",
493                                                     static_cast<char*>(errors->GetBufferPointer()));
494             }
495 
496             ComPtr<IDxcBlob> compiledShader;
497             DAWN_TRY(CheckHRESULT(result->GetResult(&compiledShader), "DXC get result"));
498             return std::move(compiledShader);
499         }
500 
CompileFlagsToStringFXC(uint32_t compileFlags)501         std::string CompileFlagsToStringFXC(uint32_t compileFlags) {
502             struct Flag {
503                 uint32_t value;
504                 const char* name;
505             };
506             constexpr Flag flags[] = {
507             // Populated from d3dcompiler.h
508 #define F(f) Flag{f, #f}
509                 F(D3DCOMPILE_DEBUG),
510                 F(D3DCOMPILE_SKIP_VALIDATION),
511                 F(D3DCOMPILE_SKIP_OPTIMIZATION),
512                 F(D3DCOMPILE_PACK_MATRIX_ROW_MAJOR),
513                 F(D3DCOMPILE_PACK_MATRIX_COLUMN_MAJOR),
514                 F(D3DCOMPILE_PARTIAL_PRECISION),
515                 F(D3DCOMPILE_FORCE_VS_SOFTWARE_NO_OPT),
516                 F(D3DCOMPILE_FORCE_PS_SOFTWARE_NO_OPT),
517                 F(D3DCOMPILE_NO_PRESHADER),
518                 F(D3DCOMPILE_AVOID_FLOW_CONTROL),
519                 F(D3DCOMPILE_PREFER_FLOW_CONTROL),
520                 F(D3DCOMPILE_ENABLE_STRICTNESS),
521                 F(D3DCOMPILE_ENABLE_BACKWARDS_COMPATIBILITY),
522                 F(D3DCOMPILE_IEEE_STRICTNESS),
523                 F(D3DCOMPILE_RESERVED16),
524                 F(D3DCOMPILE_RESERVED17),
525                 F(D3DCOMPILE_WARNINGS_ARE_ERRORS),
526                 F(D3DCOMPILE_RESOURCES_MAY_ALIAS),
527                 F(D3DCOMPILE_ENABLE_UNBOUNDED_DESCRIPTOR_TABLES),
528                 F(D3DCOMPILE_ALL_RESOURCES_BOUND),
529                 F(D3DCOMPILE_DEBUG_NAME_FOR_SOURCE),
530                 F(D3DCOMPILE_DEBUG_NAME_FOR_BINARY),
531 #undef F
532             };
533 
534             std::string result;
535             for (const Flag& f : flags) {
536                 if ((compileFlags & f.value) != 0) {
537                     result += f.name + std::string("\n");
538                 }
539             }
540 
541             // Optimization level must be handled separately as two bits are used, and the values
542             // don't map neatly to 0-3.
543             constexpr uint32_t d3dCompileFlagsBits = D3DCOMPILE_OPTIMIZATION_LEVEL2;
544             switch (compileFlags & d3dCompileFlagsBits) {
545                 case D3DCOMPILE_OPTIMIZATION_LEVEL0:
546                     result += "D3DCOMPILE_OPTIMIZATION_LEVEL0";
547                     break;
548                 case D3DCOMPILE_OPTIMIZATION_LEVEL1:
549                     result += "D3DCOMPILE_OPTIMIZATION_LEVEL1";
550                     break;
551                 case D3DCOMPILE_OPTIMIZATION_LEVEL2:
552                     result += "D3DCOMPILE_OPTIMIZATION_LEVEL2";
553                     break;
554                 case D3DCOMPILE_OPTIMIZATION_LEVEL3:
555                     result += "D3DCOMPILE_OPTIMIZATION_LEVEL3";
556                     break;
557             }
558             result += std::string("\n");
559 
560             return result;
561         }
562 
CompileShaderFXC(const PlatformFunctions * functions,const ShaderCompilationRequest & request,const std::string & hlslSource)563         ResultOrError<ComPtr<ID3DBlob>> CompileShaderFXC(const PlatformFunctions* functions,
564                                                          const ShaderCompilationRequest& request,
565                                                          const std::string& hlslSource) {
566             const char* targetProfile = nullptr;
567             switch (request.stage) {
568                 case SingleShaderStage::Vertex:
569                     targetProfile = "vs_5_1";
570                     break;
571                 case SingleShaderStage::Fragment:
572                     targetProfile = "ps_5_1";
573                     break;
574                 case SingleShaderStage::Compute:
575                     targetProfile = "cs_5_1";
576                     break;
577             }
578 
579             ComPtr<ID3DBlob> compiledShader;
580             ComPtr<ID3DBlob> errors;
581 
582             // Build defines for overridable constants
583             const D3D_SHADER_MACRO* pDefines = nullptr;
584             std::vector<D3D_SHADER_MACRO> fxcDefines;
585             if (request.defineStrings.size() > 0) {
586                 fxcDefines.reserve(request.defineStrings.size() + 1);
587                 for (const auto& d : request.defineStrings) {
588                     fxcDefines.push_back({d.first.c_str(), d.second.c_str()});
589                 }
590                 // d3dCompile D3D_SHADER_MACRO* pDefines is a nullptr terminated array
591                 fxcDefines.push_back({nullptr, nullptr});
592                 pDefines = fxcDefines.data();
593             }
594 
595             DAWN_INVALID_IF(FAILED(functions->d3dCompile(
596                                 hlslSource.c_str(), hlslSource.length(), nullptr, pDefines, nullptr,
597                                 request.entryPointName, targetProfile, request.compileFlags, 0,
598                                 &compiledShader, &errors)),
599                             "D3D compile failed with: %s",
600                             static_cast<char*>(errors->GetBufferPointer()));
601 
602             return std::move(compiledShader);
603         }
604 
TranslateToHLSL(const ShaderCompilationRequest & request,std::string * remappedEntryPointName)605         ResultOrError<std::string> TranslateToHLSL(const ShaderCompilationRequest& request,
606                                                    std::string* remappedEntryPointName) {
607             std::ostringstream errorStream;
608             errorStream << "Tint HLSL failure:" << std::endl;
609 
610             tint::transform::Manager transformManager;
611             tint::transform::DataMap transformInputs;
612 
613             if (request.isRobustnessEnabled) {
614                 transformManager.Add<tint::transform::Robustness>();
615             }
616 
617             transformManager.Add<tint::transform::BindingRemapper>();
618 
619             transformManager.Add<tint::transform::SingleEntryPoint>();
620             transformInputs.Add<tint::transform::SingleEntryPoint::Config>(request.entryPointName);
621 
622             transformManager.Add<tint::transform::Renamer>();
623 
624             if (request.disableSymbolRenaming) {
625                 // We still need to rename HLSL reserved keywords
626                 transformInputs.Add<tint::transform::Renamer::Config>(
627                     tint::transform::Renamer::Target::kHlslKeywords);
628             }
629 
630             // D3D12 registers like `t3` and `c3` have the same bindingOffset number in
631             // the remapping but should not be considered a collision because they have
632             // different types.
633             const bool mayCollide = true;
634             transformInputs.Add<tint::transform::BindingRemapper::Remappings>(
635                 std::move(request.remappedBindingPoints), std::move(request.remappedAccessControls),
636                 mayCollide);
637 
638             tint::Program transformedProgram;
639             tint::transform::DataMap transformOutputs;
640             DAWN_TRY_ASSIGN(transformedProgram,
641                             RunTransforms(&transformManager, request.program, transformInputs,
642                                           &transformOutputs, nullptr));
643 
644             if (auto* data = transformOutputs.Get<tint::transform::Renamer::Data>()) {
645                 auto it = data->remappings.find(request.entryPointName);
646                 if (it != data->remappings.end()) {
647                     *remappedEntryPointName = it->second;
648                 } else {
649                     DAWN_INVALID_IF(!request.disableSymbolRenaming,
650                                     "Could not find remapped name for entry point.");
651 
652                     *remappedEntryPointName = request.entryPointName;
653                 }
654             } else {
655                 return DAWN_FORMAT_VALIDATION_ERROR("Transform output missing renamer data.");
656             }
657 
658             tint::writer::hlsl::Options options;
659             options.disable_workgroup_init = request.disableWorkgroupInit;
660             if (request.usesNumWorkgroups) {
661                 options.root_constant_binding_point.group = request.numWorkgroupsRegisterSpace;
662                 options.root_constant_binding_point.binding = request.numWorkgroupsShaderRegister;
663             }
664             // TODO(dawn:549): HLSL generation outputs the indices into the
665             // array_length_from_uniform buffer that were actually used. When the blob cache can
666             // store more than compiled shaders, we should reflect these used indices and store
667             // them as well. This would allow us to only upload root constants that are actually
668             // read by the shader.
669             options.array_length_from_uniform = request.arrayLengthFromUniform;
670             auto result = tint::writer::hlsl::Generate(&transformedProgram, options);
671             DAWN_INVALID_IF(!result.success, "An error occured while generating HLSL: %s",
672                             result.error);
673 
674             return std::move(result.hlsl);
675         }
676 
677         template <typename F>
CompileShader(const PlatformFunctions * functions,IDxcLibrary * dxcLibrary,IDxcCompiler * dxcCompiler,ShaderCompilationRequest && request,bool dumpShaders,F && DumpShadersEmitLog,CompiledShader * compiledShader)678         MaybeError CompileShader(const PlatformFunctions* functions,
679                                  IDxcLibrary* dxcLibrary,
680                                  IDxcCompiler* dxcCompiler,
681                                  ShaderCompilationRequest&& request,
682                                  bool dumpShaders,
683                                  F&& DumpShadersEmitLog,
684                                  CompiledShader* compiledShader) {
685             // Compile the source shader to HLSL.
686             std::string hlslSource;
687             std::string remappedEntryPoint;
688             DAWN_TRY_ASSIGN(hlslSource, TranslateToHLSL(request, &remappedEntryPoint));
689             if (dumpShaders) {
690                 std::ostringstream dumpedMsg;
691                 dumpedMsg << "/* Dumped generated HLSL */" << std::endl << hlslSource;
692                 DumpShadersEmitLog(WGPULoggingType_Info, dumpedMsg.str().c_str());
693             }
694             request.entryPointName = remappedEntryPoint.c_str();
695             switch (request.compiler) {
696                 case ShaderCompilationRequest::Compiler::DXC:
697                     DAWN_TRY_ASSIGN(compiledShader->compiledDXCShader,
698                                     CompileShaderDXC(dxcLibrary, dxcCompiler, request, hlslSource));
699                     break;
700                 case ShaderCompilationRequest::Compiler::FXC:
701                     DAWN_TRY_ASSIGN(compiledShader->compiledFXCShader,
702                                     CompileShaderFXC(functions, request, hlslSource));
703                     break;
704             }
705 
706             if (dumpShaders && request.compiler == ShaderCompilationRequest::Compiler::FXC) {
707                 std::ostringstream dumpedMsg;
708                 dumpedMsg << "/* FXC compile flags */ " << std::endl
709                           << CompileFlagsToStringFXC(request.compileFlags) << std::endl;
710 
711                 dumpedMsg << "/* Dumped disassembled DXBC */" << std::endl;
712 
713                 ComPtr<ID3DBlob> disassembly;
714                 if (FAILED(functions->d3dDisassemble(
715                         compiledShader->compiledFXCShader->GetBufferPointer(),
716                         compiledShader->compiledFXCShader->GetBufferSize(), 0, nullptr,
717                         &disassembly))) {
718                     dumpedMsg << "D3D disassemble failed" << std::endl;
719                 } else {
720                     dumpedMsg << reinterpret_cast<const char*>(disassembly->GetBufferPointer());
721                 }
722                 DumpShadersEmitLog(WGPULoggingType_Info, dumpedMsg.str().c_str());
723             }
724 
725             return {};
726         }
727 
728     }  // anonymous namespace
729 
730     // static
Create(Device * device,const ShaderModuleDescriptor * descriptor,ShaderModuleParseResult * parseResult)731     ResultOrError<Ref<ShaderModule>> ShaderModule::Create(Device* device,
732                                                           const ShaderModuleDescriptor* descriptor,
733                                                           ShaderModuleParseResult* parseResult) {
734         Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
735         DAWN_TRY(module->Initialize(parseResult));
736         return module;
737     }
738 
ShaderModule(Device * device,const ShaderModuleDescriptor * descriptor)739     ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor)
740         : ShaderModuleBase(device, descriptor) {
741     }
742 
Initialize(ShaderModuleParseResult * parseResult)743     MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
744         ScopedTintICEHandler scopedICEHandler(GetDevice());
745         return InitializeBase(parseResult);
746     }
747 
Compile(const ProgrammableStage & programmableStage,SingleShaderStage stage,PipelineLayout * layout,uint32_t compileFlags)748     ResultOrError<CompiledShader> ShaderModule::Compile(const ProgrammableStage& programmableStage,
749                                                         SingleShaderStage stage,
750                                                         PipelineLayout* layout,
751                                                         uint32_t compileFlags) {
752         ASSERT(!IsError());
753         ScopedTintICEHandler scopedICEHandler(GetDevice());
754 
755         Device* device = ToBackend(GetDevice());
756 
757         CompiledShader compiledShader = {};
758 
759         tint::transform::Manager transformManager;
760         tint::transform::DataMap transformInputs;
761 
762         const tint::Program* program;
763         tint::Program programAsValue;
764         if (stage == SingleShaderStage::Vertex) {
765             transformManager.Add<tint::transform::FirstIndexOffset>();
766             transformInputs.Add<tint::transform::FirstIndexOffset::BindingPoint>(
767                 layout->GetFirstIndexOffsetShaderRegister(),
768                 layout->GetFirstIndexOffsetRegisterSpace());
769 
770             tint::transform::DataMap transformOutputs;
771             DAWN_TRY_ASSIGN(programAsValue,
772                             RunTransforms(&transformManager, GetTintProgram(), transformInputs,
773                                           &transformOutputs, nullptr));
774 
775             if (auto* data = transformOutputs.Get<tint::transform::FirstIndexOffset::Data>()) {
776                 // TODO(dawn:549): Consider adding this information to the pipeline cache once we
777                 // can store more than the shader blob in it.
778                 compiledShader.firstOffsetInfo.usesVertexIndex = data->has_vertex_index;
779                 if (compiledShader.firstOffsetInfo.usesVertexIndex) {
780                     compiledShader.firstOffsetInfo.vertexIndexOffset = data->first_vertex_offset;
781                 }
782                 compiledShader.firstOffsetInfo.usesInstanceIndex = data->has_instance_index;
783                 if (compiledShader.firstOffsetInfo.usesInstanceIndex) {
784                     compiledShader.firstOffsetInfo.instanceIndexOffset =
785                         data->first_instance_offset;
786                 }
787             }
788 
789             program = &programAsValue;
790         } else {
791             program = GetTintProgram();
792         }
793 
794         ShaderCompilationRequest request;
795         DAWN_TRY_ASSIGN(
796             request, ShaderCompilationRequest::Create(
797                          programmableStage.entryPoint.c_str(), stage, layout, compileFlags, device,
798                          program, GetEntryPoint(programmableStage.entryPoint), programmableStage));
799 
800         PersistentCacheKey shaderCacheKey;
801         DAWN_TRY_ASSIGN(shaderCacheKey, request.CreateCacheKey());
802 
803         DAWN_TRY_ASSIGN(
804             compiledShader.cachedShader,
805             device->GetPersistentCache()->GetOrCreate(
806                 shaderCacheKey, [&](auto doCache) -> MaybeError {
807                     DAWN_TRY(CompileShader(
808                         device->GetFunctions(),
809                         device->IsToggleEnabled(Toggle::UseDXC) ? device->GetDxcLibrary().Get()
810                                                                 : nullptr,
811                         device->IsToggleEnabled(Toggle::UseDXC) ? device->GetDxcCompiler().Get()
812                                                                 : nullptr,
813                         std::move(request), device->IsToggleEnabled(Toggle::DumpShaders),
814                         [&](WGPULoggingType loggingType, const char* message) {
815                             GetDevice()->EmitLog(loggingType, message);
816                         },
817                         &compiledShader));
818                     const D3D12_SHADER_BYTECODE shader = compiledShader.GetD3D12ShaderBytecode();
819                     doCache(shader.pShaderBytecode, shader.BytecodeLength);
820                     return {};
821                 }));
822 
823         return std::move(compiledShader);
824     }
825 
GetD3D12ShaderBytecode() const826     D3D12_SHADER_BYTECODE CompiledShader::GetD3D12ShaderBytecode() const {
827         if (cachedShader.buffer != nullptr) {
828             return {cachedShader.buffer.get(), cachedShader.bufferSize};
829         } else if (compiledFXCShader != nullptr) {
830             return {compiledFXCShader->GetBufferPointer(), compiledFXCShader->GetBufferSize()};
831         } else if (compiledDXCShader != nullptr) {
832             return {compiledDXCShader->GetBufferPointer(), compiledDXCShader->GetBufferSize()};
833         }
834         UNREACHABLE();
835         return {};
836     }
837 }}  // namespace dawn_native::d3d12
838