1 // Copyright 2017 The Dawn Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include "dawn_native/d3d12/ShaderModuleD3D12.h" 16 17 #include "common/Assert.h" 18 #include "common/BitSetIterator.h" 19 #include "common/Log.h" 20 #include "common/WindowsUtils.h" 21 #include "dawn_native/Pipeline.h" 22 #include "dawn_native/TintUtils.h" 23 #include "dawn_native/d3d12/BindGroupLayoutD3D12.h" 24 #include "dawn_native/d3d12/D3D12Error.h" 25 #include "dawn_native/d3d12/DeviceD3D12.h" 26 #include "dawn_native/d3d12/PipelineLayoutD3D12.h" 27 #include "dawn_native/d3d12/PlatformFunctions.h" 28 #include "dawn_native/d3d12/UtilsD3D12.h" 29 30 #include <d3dcompiler.h> 31 32 #include <tint/tint.h> 33 #include <map> 34 #include <sstream> 35 #include <unordered_map> 36 37 namespace dawn_native { namespace d3d12 { 38 39 namespace { GetDXCompilerVersion(ComPtr<IDxcValidator> dxcValidator)40 ResultOrError<uint64_t> GetDXCompilerVersion(ComPtr<IDxcValidator> dxcValidator) { 41 ComPtr<IDxcVersionInfo> versionInfo; 42 DAWN_TRY(CheckHRESULT(dxcValidator.As(&versionInfo), 43 "D3D12 QueryInterface IDxcValidator to IDxcVersionInfo")); 44 45 uint32_t compilerMajor, compilerMinor; 46 DAWN_TRY(CheckHRESULT(versionInfo->GetVersion(&compilerMajor, &compilerMinor), 47 "IDxcVersionInfo::GetVersion")); 48 49 // Pack both into a single version number. 50 return (uint64_t(compilerMajor) << uint64_t(32)) + compilerMinor; 51 } 52 GetD3DCompilerVersion()53 uint64_t GetD3DCompilerVersion() { 54 return D3D_COMPILER_VERSION; 55 } 56 57 struct CompareBindingPoint { operator ()dawn_native::d3d12::__anone61f89010111::CompareBindingPoint58 constexpr bool operator()(const tint::transform::BindingPoint& lhs, 59 const tint::transform::BindingPoint& rhs) const { 60 if (lhs.group != rhs.group) { 61 return lhs.group < rhs.group; 62 } else { 63 return lhs.binding < rhs.binding; 64 } 65 } 66 }; 67 Serialize(std::stringstream & output,const tint::ast::Access & access)68 void Serialize(std::stringstream& output, const tint::ast::Access& access) { 69 output << access; 70 } 71 Serialize(std::stringstream & output,const tint::transform::BindingPoint & binding_point)72 void Serialize(std::stringstream& output, 73 const tint::transform::BindingPoint& binding_point) { 74 output << "(BindingPoint"; 75 output << " group=" << binding_point.group; 76 output << " binding=" << binding_point.binding; 77 output << ")"; 78 } 79 80 template <typename T, 81 typename = typename std::enable_if<std::is_fundamental<T>::value>::type> Serialize(std::stringstream & output,const T & val)82 void Serialize(std::stringstream& output, const T& val) { 83 output << val; 84 } 85 86 template <typename T> Serialize(std::stringstream & output,const std::unordered_map<tint::transform::BindingPoint,T> & map)87 void Serialize(std::stringstream& output, 88 const std::unordered_map<tint::transform::BindingPoint, T>& map) { 89 output << "(map"; 90 91 std::map<tint::transform::BindingPoint, T, CompareBindingPoint> sorted(map.begin(), 92 map.end()); 93 for (auto& entry : sorted) { 94 output << " "; 95 Serialize(output, entry.first); 96 output << "="; 97 Serialize(output, entry.second); 98 } 99 output << ")"; 100 } 101 Serialize(std::stringstream & output,const tint::writer::ArrayLengthFromUniformOptions & arrayLengthFromUniform)102 void Serialize(std::stringstream& output, 103 const tint::writer::ArrayLengthFromUniformOptions& arrayLengthFromUniform) { 104 output << "(ArrayLengthFromUniformOptions"; 105 output << " ubo_binding="; 106 Serialize(output, arrayLengthFromUniform.ubo_binding); 107 output << " bindpoint_to_size_index="; 108 Serialize(output, arrayLengthFromUniform.bindpoint_to_size_index); 109 output << ")"; 110 } 111 112 // 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough FloatToStringWithPrecision(float v,std::streamsize n=8)113 std::string FloatToStringWithPrecision(float v, std::streamsize n = 8) { 114 std::ostringstream out; 115 out.precision(n); 116 out << std::fixed << v; 117 return out.str(); 118 } 119 GetHLSLValueString(EntryPointMetadata::OverridableConstant::Type dawnType,const OverridableConstantScalar * entry,double value=0)120 std::string GetHLSLValueString(EntryPointMetadata::OverridableConstant::Type dawnType, 121 const OverridableConstantScalar* entry, 122 double value = 0) { 123 switch (dawnType) { 124 case EntryPointMetadata::OverridableConstant::Type::Boolean: 125 return std::to_string(entry ? entry->b : static_cast<int32_t>(value)); 126 case EntryPointMetadata::OverridableConstant::Type::Float32: 127 return FloatToStringWithPrecision(entry ? entry->f32 128 : static_cast<float>(value)); 129 case EntryPointMetadata::OverridableConstant::Type::Int32: 130 return std::to_string(entry ? entry->i32 : static_cast<int32_t>(value)); 131 case EntryPointMetadata::OverridableConstant::Type::Uint32: 132 return std::to_string(entry ? entry->u32 : static_cast<uint32_t>(value)); 133 default: 134 UNREACHABLE(); 135 } 136 } 137 138 constexpr char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_"; 139 GetOverridableConstantsDefines(std::vector<std::pair<std::string,std::string>> * defineStrings,const PipelineConstantEntries * pipelineConstantEntries,const EntryPointMetadata::OverridableConstantsMap * shaderEntryPointConstants)140 void GetOverridableConstantsDefines( 141 std::vector<std::pair<std::string, std::string>>* defineStrings, 142 const PipelineConstantEntries* pipelineConstantEntries, 143 const EntryPointMetadata::OverridableConstantsMap* shaderEntryPointConstants) { 144 std::unordered_set<std::string> overriddenConstants; 145 146 // Set pipeline overridden values 147 for (const auto& pipelineConstant : *pipelineConstantEntries) { 148 const std::string& name = pipelineConstant.first; 149 double value = pipelineConstant.second; 150 151 overriddenConstants.insert(name); 152 153 // This is already validated so `name` must exist 154 const auto& moduleConstant = shaderEntryPointConstants->at(name); 155 156 defineStrings->emplace_back( 157 kSpecConstantPrefix + std::to_string(static_cast<int32_t>(moduleConstant.id)), 158 GetHLSLValueString(moduleConstant.type, nullptr, value)); 159 } 160 161 // Set shader initialized default values 162 for (const auto& iter : *shaderEntryPointConstants) { 163 const std::string& name = iter.first; 164 if (overriddenConstants.count(name) != 0) { 165 // This constant already has overridden value 166 continue; 167 } 168 169 const auto& moduleConstant = shaderEntryPointConstants->at(name); 170 171 // Uninitialized default values are okay since they ar only defined to pass 172 // compilation but not used 173 defineStrings->emplace_back( 174 kSpecConstantPrefix + std::to_string(static_cast<int32_t>(moduleConstant.id)), 175 GetHLSLValueString(moduleConstant.type, &moduleConstant.defaultValue)); 176 } 177 } 178 179 // The inputs to a shader compilation. These have been intentionally isolated from the 180 // device to help ensure that the pipeline cache key contains all inputs for compilation. 181 struct ShaderCompilationRequest { 182 enum Compiler { FXC, DXC }; 183 184 // Common inputs 185 Compiler compiler; 186 const tint::Program* program; 187 const char* entryPointName; 188 SingleShaderStage stage; 189 uint32_t compileFlags; 190 bool disableSymbolRenaming; 191 tint::transform::BindingRemapper::BindingPoints remappedBindingPoints; 192 tint::transform::BindingRemapper::AccessControls remappedAccessControls; 193 bool isRobustnessEnabled; 194 bool usesNumWorkgroups; 195 uint32_t numWorkgroupsRegisterSpace; 196 uint32_t numWorkgroupsShaderRegister; 197 tint::writer::ArrayLengthFromUniformOptions arrayLengthFromUniform; 198 std::vector<std::pair<std::string, std::string>> defineStrings; 199 200 // FXC/DXC common inputs 201 bool disableWorkgroupInit; 202 203 // FXC inputs 204 uint64_t fxcVersion; 205 206 // DXC inputs 207 uint64_t dxcVersion; 208 const D3D12DeviceInfo* deviceInfo; 209 bool hasShaderFloat16Feature; 210 Createdawn_native::d3d12::__anone61f89010111::ShaderCompilationRequest211 static ResultOrError<ShaderCompilationRequest> Create( 212 const char* entryPointName, 213 SingleShaderStage stage, 214 const PipelineLayout* layout, 215 uint32_t compileFlags, 216 const Device* device, 217 const tint::Program* program, 218 const EntryPointMetadata& entryPoint, 219 const ProgrammableStage& programmableStage) { 220 Compiler compiler; 221 uint64_t dxcVersion = 0; 222 if (device->IsToggleEnabled(Toggle::UseDXC)) { 223 compiler = Compiler::DXC; 224 DAWN_TRY_ASSIGN(dxcVersion, GetDXCompilerVersion(device->GetDxcValidator())); 225 } else { 226 compiler = Compiler::FXC; 227 } 228 229 using tint::transform::BindingPoint; 230 using tint::transform::BindingRemapper; 231 232 BindingRemapper::BindingPoints remappedBindingPoints; 233 BindingRemapper::AccessControls remappedAccessControls; 234 235 tint::writer::ArrayLengthFromUniformOptions arrayLengthFromUniform; 236 arrayLengthFromUniform.ubo_binding = { 237 layout->GetDynamicStorageBufferLengthsRegisterSpace(), 238 layout->GetDynamicStorageBufferLengthsShaderRegister()}; 239 240 const BindingInfoArray& moduleBindingInfo = entryPoint.bindings; 241 for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) { 242 const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group)); 243 const auto& groupBindingInfo = moduleBindingInfo[group]; 244 245 // d3d12::BindGroupLayout packs the bindings per HLSL register-space. We modify 246 // the Tint AST to make the "bindings" decoration match the offset chosen by 247 // d3d12::BindGroupLayout so that Tint produces HLSL with the correct registers 248 // assigned to each interface variable. 249 for (const auto& it : groupBindingInfo) { 250 BindingNumber binding = it.first; 251 auto const& bindingInfo = it.second; 252 BindingIndex bindingIndex = bgl->GetBindingIndex(binding); 253 BindingPoint srcBindingPoint{static_cast<uint32_t>(group), 254 static_cast<uint32_t>(binding)}; 255 BindingPoint dstBindingPoint{static_cast<uint32_t>(group), 256 bgl->GetShaderRegister(bindingIndex)}; 257 if (srcBindingPoint != dstBindingPoint) { 258 remappedBindingPoints.emplace(srcBindingPoint, dstBindingPoint); 259 } 260 261 // Declaring a read-only storage buffer in HLSL but specifying a storage 262 // buffer in the BGL produces the wrong output. Force read-only storage 263 // buffer bindings to be treated as UAV instead of SRV. Internal storage 264 // buffer is a storage buffer used in the internal pipeline. 265 const bool forceStorageBufferAsUAV = 266 (bindingInfo.buffer.type == wgpu::BufferBindingType::ReadOnlyStorage && 267 (bgl->GetBindingInfo(bindingIndex).buffer.type == 268 wgpu::BufferBindingType::Storage || 269 bgl->GetBindingInfo(bindingIndex).buffer.type == 270 kInternalStorageBufferBinding)); 271 if (forceStorageBufferAsUAV) { 272 remappedAccessControls.emplace(srcBindingPoint, 273 tint::ast::Access::kReadWrite); 274 } 275 } 276 277 // Add arrayLengthFromUniform options 278 { 279 for (const auto& bindingAndRegisterOffset : 280 layout->GetDynamicStorageBufferLengthInfo()[group] 281 .bindingAndRegisterOffsets) { 282 BindingNumber binding = bindingAndRegisterOffset.binding; 283 uint32_t registerOffset = bindingAndRegisterOffset.registerOffset; 284 285 BindingPoint bindingPoint{static_cast<uint32_t>(group), 286 static_cast<uint32_t>(binding)}; 287 // Get the renamed binding point if it was remapped. 288 auto it = remappedBindingPoints.find(bindingPoint); 289 if (it != remappedBindingPoints.end()) { 290 bindingPoint = it->second; 291 } 292 293 arrayLengthFromUniform.bindpoint_to_size_index.emplace(bindingPoint, 294 registerOffset); 295 } 296 } 297 } 298 299 ShaderCompilationRequest request; 300 request.compiler = compiler; 301 request.program = program; 302 request.entryPointName = entryPointName; 303 request.stage = stage; 304 request.compileFlags = compileFlags; 305 request.disableSymbolRenaming = 306 device->IsToggleEnabled(Toggle::DisableSymbolRenaming); 307 request.remappedBindingPoints = std::move(remappedBindingPoints); 308 request.remappedAccessControls = std::move(remappedAccessControls); 309 request.isRobustnessEnabled = device->IsRobustnessEnabled(); 310 request.disableWorkgroupInit = 311 device->IsToggleEnabled(Toggle::DisableWorkgroupInit); 312 request.usesNumWorkgroups = entryPoint.usesNumWorkgroups; 313 request.numWorkgroupsShaderRegister = layout->GetNumWorkgroupsShaderRegister(); 314 request.numWorkgroupsRegisterSpace = layout->GetNumWorkgroupsRegisterSpace(); 315 request.arrayLengthFromUniform = std::move(arrayLengthFromUniform); 316 request.fxcVersion = compiler == Compiler::FXC ? GetD3DCompilerVersion() : 0; 317 request.dxcVersion = compiler == Compiler::DXC ? dxcVersion : 0; 318 request.deviceInfo = &device->GetDeviceInfo(); 319 request.hasShaderFloat16Feature = device->IsFeatureEnabled(Feature::ShaderFloat16); 320 321 GetOverridableConstantsDefines( 322 &request.defineStrings, &programmableStage.constants, 323 &programmableStage.module->GetEntryPoint(programmableStage.entryPoint) 324 .overridableConstants); 325 326 return std::move(request); 327 } 328 CreateCacheKeydawn_native::d3d12::__anone61f89010111::ShaderCompilationRequest329 ResultOrError<PersistentCacheKey> CreateCacheKey() const { 330 // Generate the WGSL from the Tint program so it's normalized. 331 // TODO(tint:1180): Consider using a binary serialization of the tint AST for a more 332 // compact representation. 333 auto result = tint::writer::wgsl::Generate(program, tint::writer::wgsl::Options{}); 334 if (!result.success) { 335 std::ostringstream errorStream; 336 errorStream << "Tint WGSL failure:" << std::endl; 337 errorStream << "Generator: " << result.error << std::endl; 338 return DAWN_INTERNAL_ERROR(errorStream.str().c_str()); 339 } 340 341 std::stringstream stream; 342 343 // Prefix the key with the type to avoid collisions from another type that could 344 // have the same key. 345 stream << static_cast<uint32_t>(PersistentKeyType::Shader); 346 stream << "\n"; 347 348 stream << result.wgsl.length(); 349 stream << "\n"; 350 351 stream << result.wgsl; 352 stream << "\n"; 353 354 stream << "(ShaderCompilationRequest"; 355 stream << " compiler=" << compiler; 356 stream << " entryPointName=" << entryPointName; 357 stream << " stage=" << uint32_t(stage); 358 stream << " compileFlags=" << compileFlags; 359 stream << " disableSymbolRenaming=" << disableSymbolRenaming; 360 361 stream << " remappedBindingPoints="; 362 Serialize(stream, remappedBindingPoints); 363 364 stream << " remappedAccessControls="; 365 Serialize(stream, remappedAccessControls); 366 367 stream << " useNumWorkgroups=" << usesNumWorkgroups; 368 stream << " numWorkgroupsRegisterSpace=" << numWorkgroupsRegisterSpace; 369 stream << " numWorkgroupsShaderRegister=" << numWorkgroupsShaderRegister; 370 371 stream << " arrayLengthFromUniform="; 372 Serialize(stream, arrayLengthFromUniform); 373 374 stream << " shaderModel=" << deviceInfo->shaderModel; 375 stream << " disableWorkgroupInit=" << disableWorkgroupInit; 376 stream << " isRobustnessEnabled=" << isRobustnessEnabled; 377 stream << " fxcVersion=" << fxcVersion; 378 stream << " dxcVersion=" << dxcVersion; 379 stream << " hasShaderFloat16Feature=" << hasShaderFloat16Feature; 380 381 stream << " defines={"; 382 for (const auto& it : defineStrings) { 383 stream << " <" << it.first << "," << it.second << ">"; 384 } 385 stream << " }"; 386 387 stream << ")"; 388 stream << "\n"; 389 390 return PersistentCacheKey(std::istreambuf_iterator<char>{stream}, 391 std::istreambuf_iterator<char>{}); 392 } 393 }; 394 GetDXCArguments(uint32_t compileFlags,bool enable16BitTypes)395 std::vector<const wchar_t*> GetDXCArguments(uint32_t compileFlags, bool enable16BitTypes) { 396 std::vector<const wchar_t*> arguments; 397 if (compileFlags & D3DCOMPILE_ENABLE_BACKWARDS_COMPATIBILITY) { 398 arguments.push_back(L"/Gec"); 399 } 400 if (compileFlags & D3DCOMPILE_IEEE_STRICTNESS) { 401 arguments.push_back(L"/Gis"); 402 } 403 constexpr uint32_t d3dCompileFlagsBits = D3DCOMPILE_OPTIMIZATION_LEVEL2; 404 if (compileFlags & d3dCompileFlagsBits) { 405 switch (compileFlags & D3DCOMPILE_OPTIMIZATION_LEVEL2) { 406 case D3DCOMPILE_OPTIMIZATION_LEVEL0: 407 arguments.push_back(L"/O0"); 408 break; 409 case D3DCOMPILE_OPTIMIZATION_LEVEL2: 410 arguments.push_back(L"/O2"); 411 break; 412 case D3DCOMPILE_OPTIMIZATION_LEVEL3: 413 arguments.push_back(L"/O3"); 414 break; 415 } 416 } 417 if (compileFlags & D3DCOMPILE_DEBUG) { 418 arguments.push_back(L"/Zi"); 419 } 420 if (compileFlags & D3DCOMPILE_PACK_MATRIX_ROW_MAJOR) { 421 arguments.push_back(L"/Zpr"); 422 } 423 if (compileFlags & D3DCOMPILE_PACK_MATRIX_COLUMN_MAJOR) { 424 arguments.push_back(L"/Zpc"); 425 } 426 if (compileFlags & D3DCOMPILE_AVOID_FLOW_CONTROL) { 427 arguments.push_back(L"/Gfa"); 428 } 429 if (compileFlags & D3DCOMPILE_PREFER_FLOW_CONTROL) { 430 arguments.push_back(L"/Gfp"); 431 } 432 if (compileFlags & D3DCOMPILE_RESOURCES_MAY_ALIAS) { 433 arguments.push_back(L"/res_may_alias"); 434 } 435 436 if (enable16BitTypes) { 437 // enable-16bit-types are only allowed in -HV 2018 (default) 438 arguments.push_back(L"/enable-16bit-types"); 439 } 440 441 arguments.push_back(L"-HV"); 442 arguments.push_back(L"2018"); 443 444 return arguments; 445 } 446 CompileShaderDXC(IDxcLibrary * dxcLibrary,IDxcCompiler * dxcCompiler,const ShaderCompilationRequest & request,const std::string & hlslSource)447 ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(IDxcLibrary* dxcLibrary, 448 IDxcCompiler* dxcCompiler, 449 const ShaderCompilationRequest& request, 450 const std::string& hlslSource) { 451 ComPtr<IDxcBlobEncoding> sourceBlob; 452 DAWN_TRY( 453 CheckHRESULT(dxcLibrary->CreateBlobWithEncodingOnHeapCopy( 454 hlslSource.c_str(), hlslSource.length(), CP_UTF8, &sourceBlob), 455 "DXC create blob")); 456 457 std::wstring entryPointW; 458 DAWN_TRY_ASSIGN(entryPointW, ConvertStringToWstring(request.entryPointName)); 459 460 std::vector<const wchar_t*> arguments = 461 GetDXCArguments(request.compileFlags, request.hasShaderFloat16Feature); 462 463 // Build defines for overridable constants 464 std::vector<std::pair<std::wstring, std::wstring>> defineStrings; 465 defineStrings.reserve(request.defineStrings.size()); 466 for (const auto& it : request.defineStrings) { 467 defineStrings.emplace_back(UTF8ToWStr(it.first.c_str()), 468 UTF8ToWStr(it.second.c_str())); 469 } 470 471 std::vector<DxcDefine> dxcDefines; 472 dxcDefines.reserve(defineStrings.size()); 473 for (const auto& d : defineStrings) { 474 dxcDefines.push_back({d.first.c_str(), d.second.c_str()}); 475 } 476 477 ComPtr<IDxcOperationResult> result; 478 DAWN_TRY(CheckHRESULT( 479 dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(), 480 request.deviceInfo->shaderProfiles[request.stage].c_str(), 481 arguments.data(), arguments.size(), dxcDefines.data(), 482 dxcDefines.size(), nullptr, &result), 483 "DXC compile")); 484 485 HRESULT hr; 486 DAWN_TRY(CheckHRESULT(result->GetStatus(&hr), "DXC get status")); 487 488 if (FAILED(hr)) { 489 ComPtr<IDxcBlobEncoding> errors; 490 DAWN_TRY(CheckHRESULT(result->GetErrorBuffer(&errors), "DXC get error buffer")); 491 492 return DAWN_FORMAT_VALIDATION_ERROR("DXC compile failed with: %s", 493 static_cast<char*>(errors->GetBufferPointer())); 494 } 495 496 ComPtr<IDxcBlob> compiledShader; 497 DAWN_TRY(CheckHRESULT(result->GetResult(&compiledShader), "DXC get result")); 498 return std::move(compiledShader); 499 } 500 CompileFlagsToStringFXC(uint32_t compileFlags)501 std::string CompileFlagsToStringFXC(uint32_t compileFlags) { 502 struct Flag { 503 uint32_t value; 504 const char* name; 505 }; 506 constexpr Flag flags[] = { 507 // Populated from d3dcompiler.h 508 #define F(f) Flag{f, #f} 509 F(D3DCOMPILE_DEBUG), 510 F(D3DCOMPILE_SKIP_VALIDATION), 511 F(D3DCOMPILE_SKIP_OPTIMIZATION), 512 F(D3DCOMPILE_PACK_MATRIX_ROW_MAJOR), 513 F(D3DCOMPILE_PACK_MATRIX_COLUMN_MAJOR), 514 F(D3DCOMPILE_PARTIAL_PRECISION), 515 F(D3DCOMPILE_FORCE_VS_SOFTWARE_NO_OPT), 516 F(D3DCOMPILE_FORCE_PS_SOFTWARE_NO_OPT), 517 F(D3DCOMPILE_NO_PRESHADER), 518 F(D3DCOMPILE_AVOID_FLOW_CONTROL), 519 F(D3DCOMPILE_PREFER_FLOW_CONTROL), 520 F(D3DCOMPILE_ENABLE_STRICTNESS), 521 F(D3DCOMPILE_ENABLE_BACKWARDS_COMPATIBILITY), 522 F(D3DCOMPILE_IEEE_STRICTNESS), 523 F(D3DCOMPILE_RESERVED16), 524 F(D3DCOMPILE_RESERVED17), 525 F(D3DCOMPILE_WARNINGS_ARE_ERRORS), 526 F(D3DCOMPILE_RESOURCES_MAY_ALIAS), 527 F(D3DCOMPILE_ENABLE_UNBOUNDED_DESCRIPTOR_TABLES), 528 F(D3DCOMPILE_ALL_RESOURCES_BOUND), 529 F(D3DCOMPILE_DEBUG_NAME_FOR_SOURCE), 530 F(D3DCOMPILE_DEBUG_NAME_FOR_BINARY), 531 #undef F 532 }; 533 534 std::string result; 535 for (const Flag& f : flags) { 536 if ((compileFlags & f.value) != 0) { 537 result += f.name + std::string("\n"); 538 } 539 } 540 541 // Optimization level must be handled separately as two bits are used, and the values 542 // don't map neatly to 0-3. 543 constexpr uint32_t d3dCompileFlagsBits = D3DCOMPILE_OPTIMIZATION_LEVEL2; 544 switch (compileFlags & d3dCompileFlagsBits) { 545 case D3DCOMPILE_OPTIMIZATION_LEVEL0: 546 result += "D3DCOMPILE_OPTIMIZATION_LEVEL0"; 547 break; 548 case D3DCOMPILE_OPTIMIZATION_LEVEL1: 549 result += "D3DCOMPILE_OPTIMIZATION_LEVEL1"; 550 break; 551 case D3DCOMPILE_OPTIMIZATION_LEVEL2: 552 result += "D3DCOMPILE_OPTIMIZATION_LEVEL2"; 553 break; 554 case D3DCOMPILE_OPTIMIZATION_LEVEL3: 555 result += "D3DCOMPILE_OPTIMIZATION_LEVEL3"; 556 break; 557 } 558 result += std::string("\n"); 559 560 return result; 561 } 562 CompileShaderFXC(const PlatformFunctions * functions,const ShaderCompilationRequest & request,const std::string & hlslSource)563 ResultOrError<ComPtr<ID3DBlob>> CompileShaderFXC(const PlatformFunctions* functions, 564 const ShaderCompilationRequest& request, 565 const std::string& hlslSource) { 566 const char* targetProfile = nullptr; 567 switch (request.stage) { 568 case SingleShaderStage::Vertex: 569 targetProfile = "vs_5_1"; 570 break; 571 case SingleShaderStage::Fragment: 572 targetProfile = "ps_5_1"; 573 break; 574 case SingleShaderStage::Compute: 575 targetProfile = "cs_5_1"; 576 break; 577 } 578 579 ComPtr<ID3DBlob> compiledShader; 580 ComPtr<ID3DBlob> errors; 581 582 // Build defines for overridable constants 583 const D3D_SHADER_MACRO* pDefines = nullptr; 584 std::vector<D3D_SHADER_MACRO> fxcDefines; 585 if (request.defineStrings.size() > 0) { 586 fxcDefines.reserve(request.defineStrings.size() + 1); 587 for (const auto& d : request.defineStrings) { 588 fxcDefines.push_back({d.first.c_str(), d.second.c_str()}); 589 } 590 // d3dCompile D3D_SHADER_MACRO* pDefines is a nullptr terminated array 591 fxcDefines.push_back({nullptr, nullptr}); 592 pDefines = fxcDefines.data(); 593 } 594 595 DAWN_INVALID_IF(FAILED(functions->d3dCompile( 596 hlslSource.c_str(), hlslSource.length(), nullptr, pDefines, nullptr, 597 request.entryPointName, targetProfile, request.compileFlags, 0, 598 &compiledShader, &errors)), 599 "D3D compile failed with: %s", 600 static_cast<char*>(errors->GetBufferPointer())); 601 602 return std::move(compiledShader); 603 } 604 TranslateToHLSL(const ShaderCompilationRequest & request,std::string * remappedEntryPointName)605 ResultOrError<std::string> TranslateToHLSL(const ShaderCompilationRequest& request, 606 std::string* remappedEntryPointName) { 607 std::ostringstream errorStream; 608 errorStream << "Tint HLSL failure:" << std::endl; 609 610 tint::transform::Manager transformManager; 611 tint::transform::DataMap transformInputs; 612 613 if (request.isRobustnessEnabled) { 614 transformManager.Add<tint::transform::Robustness>(); 615 } 616 617 transformManager.Add<tint::transform::BindingRemapper>(); 618 619 transformManager.Add<tint::transform::SingleEntryPoint>(); 620 transformInputs.Add<tint::transform::SingleEntryPoint::Config>(request.entryPointName); 621 622 transformManager.Add<tint::transform::Renamer>(); 623 624 if (request.disableSymbolRenaming) { 625 // We still need to rename HLSL reserved keywords 626 transformInputs.Add<tint::transform::Renamer::Config>( 627 tint::transform::Renamer::Target::kHlslKeywords); 628 } 629 630 // D3D12 registers like `t3` and `c3` have the same bindingOffset number in 631 // the remapping but should not be considered a collision because they have 632 // different types. 633 const bool mayCollide = true; 634 transformInputs.Add<tint::transform::BindingRemapper::Remappings>( 635 std::move(request.remappedBindingPoints), std::move(request.remappedAccessControls), 636 mayCollide); 637 638 tint::Program transformedProgram; 639 tint::transform::DataMap transformOutputs; 640 DAWN_TRY_ASSIGN(transformedProgram, 641 RunTransforms(&transformManager, request.program, transformInputs, 642 &transformOutputs, nullptr)); 643 644 if (auto* data = transformOutputs.Get<tint::transform::Renamer::Data>()) { 645 auto it = data->remappings.find(request.entryPointName); 646 if (it != data->remappings.end()) { 647 *remappedEntryPointName = it->second; 648 } else { 649 DAWN_INVALID_IF(!request.disableSymbolRenaming, 650 "Could not find remapped name for entry point."); 651 652 *remappedEntryPointName = request.entryPointName; 653 } 654 } else { 655 return DAWN_FORMAT_VALIDATION_ERROR("Transform output missing renamer data."); 656 } 657 658 tint::writer::hlsl::Options options; 659 options.disable_workgroup_init = request.disableWorkgroupInit; 660 if (request.usesNumWorkgroups) { 661 options.root_constant_binding_point.group = request.numWorkgroupsRegisterSpace; 662 options.root_constant_binding_point.binding = request.numWorkgroupsShaderRegister; 663 } 664 // TODO(dawn:549): HLSL generation outputs the indices into the 665 // array_length_from_uniform buffer that were actually used. When the blob cache can 666 // store more than compiled shaders, we should reflect these used indices and store 667 // them as well. This would allow us to only upload root constants that are actually 668 // read by the shader. 669 options.array_length_from_uniform = request.arrayLengthFromUniform; 670 auto result = tint::writer::hlsl::Generate(&transformedProgram, options); 671 DAWN_INVALID_IF(!result.success, "An error occured while generating HLSL: %s", 672 result.error); 673 674 return std::move(result.hlsl); 675 } 676 677 template <typename F> CompileShader(const PlatformFunctions * functions,IDxcLibrary * dxcLibrary,IDxcCompiler * dxcCompiler,ShaderCompilationRequest && request,bool dumpShaders,F && DumpShadersEmitLog,CompiledShader * compiledShader)678 MaybeError CompileShader(const PlatformFunctions* functions, 679 IDxcLibrary* dxcLibrary, 680 IDxcCompiler* dxcCompiler, 681 ShaderCompilationRequest&& request, 682 bool dumpShaders, 683 F&& DumpShadersEmitLog, 684 CompiledShader* compiledShader) { 685 // Compile the source shader to HLSL. 686 std::string hlslSource; 687 std::string remappedEntryPoint; 688 DAWN_TRY_ASSIGN(hlslSource, TranslateToHLSL(request, &remappedEntryPoint)); 689 if (dumpShaders) { 690 std::ostringstream dumpedMsg; 691 dumpedMsg << "/* Dumped generated HLSL */" << std::endl << hlslSource; 692 DumpShadersEmitLog(WGPULoggingType_Info, dumpedMsg.str().c_str()); 693 } 694 request.entryPointName = remappedEntryPoint.c_str(); 695 switch (request.compiler) { 696 case ShaderCompilationRequest::Compiler::DXC: 697 DAWN_TRY_ASSIGN(compiledShader->compiledDXCShader, 698 CompileShaderDXC(dxcLibrary, dxcCompiler, request, hlslSource)); 699 break; 700 case ShaderCompilationRequest::Compiler::FXC: 701 DAWN_TRY_ASSIGN(compiledShader->compiledFXCShader, 702 CompileShaderFXC(functions, request, hlslSource)); 703 break; 704 } 705 706 if (dumpShaders && request.compiler == ShaderCompilationRequest::Compiler::FXC) { 707 std::ostringstream dumpedMsg; 708 dumpedMsg << "/* FXC compile flags */ " << std::endl 709 << CompileFlagsToStringFXC(request.compileFlags) << std::endl; 710 711 dumpedMsg << "/* Dumped disassembled DXBC */" << std::endl; 712 713 ComPtr<ID3DBlob> disassembly; 714 if (FAILED(functions->d3dDisassemble( 715 compiledShader->compiledFXCShader->GetBufferPointer(), 716 compiledShader->compiledFXCShader->GetBufferSize(), 0, nullptr, 717 &disassembly))) { 718 dumpedMsg << "D3D disassemble failed" << std::endl; 719 } else { 720 dumpedMsg << reinterpret_cast<const char*>(disassembly->GetBufferPointer()); 721 } 722 DumpShadersEmitLog(WGPULoggingType_Info, dumpedMsg.str().c_str()); 723 } 724 725 return {}; 726 } 727 728 } // anonymous namespace 729 730 // static Create(Device * device,const ShaderModuleDescriptor * descriptor,ShaderModuleParseResult * parseResult)731 ResultOrError<Ref<ShaderModule>> ShaderModule::Create(Device* device, 732 const ShaderModuleDescriptor* descriptor, 733 ShaderModuleParseResult* parseResult) { 734 Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor)); 735 DAWN_TRY(module->Initialize(parseResult)); 736 return module; 737 } 738 ShaderModule(Device * device,const ShaderModuleDescriptor * descriptor)739 ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor) 740 : ShaderModuleBase(device, descriptor) { 741 } 742 Initialize(ShaderModuleParseResult * parseResult)743 MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) { 744 ScopedTintICEHandler scopedICEHandler(GetDevice()); 745 return InitializeBase(parseResult); 746 } 747 Compile(const ProgrammableStage & programmableStage,SingleShaderStage stage,PipelineLayout * layout,uint32_t compileFlags)748 ResultOrError<CompiledShader> ShaderModule::Compile(const ProgrammableStage& programmableStage, 749 SingleShaderStage stage, 750 PipelineLayout* layout, 751 uint32_t compileFlags) { 752 ASSERT(!IsError()); 753 ScopedTintICEHandler scopedICEHandler(GetDevice()); 754 755 Device* device = ToBackend(GetDevice()); 756 757 CompiledShader compiledShader = {}; 758 759 tint::transform::Manager transformManager; 760 tint::transform::DataMap transformInputs; 761 762 const tint::Program* program; 763 tint::Program programAsValue; 764 if (stage == SingleShaderStage::Vertex) { 765 transformManager.Add<tint::transform::FirstIndexOffset>(); 766 transformInputs.Add<tint::transform::FirstIndexOffset::BindingPoint>( 767 layout->GetFirstIndexOffsetShaderRegister(), 768 layout->GetFirstIndexOffsetRegisterSpace()); 769 770 tint::transform::DataMap transformOutputs; 771 DAWN_TRY_ASSIGN(programAsValue, 772 RunTransforms(&transformManager, GetTintProgram(), transformInputs, 773 &transformOutputs, nullptr)); 774 775 if (auto* data = transformOutputs.Get<tint::transform::FirstIndexOffset::Data>()) { 776 // TODO(dawn:549): Consider adding this information to the pipeline cache once we 777 // can store more than the shader blob in it. 778 compiledShader.firstOffsetInfo.usesVertexIndex = data->has_vertex_index; 779 if (compiledShader.firstOffsetInfo.usesVertexIndex) { 780 compiledShader.firstOffsetInfo.vertexIndexOffset = data->first_vertex_offset; 781 } 782 compiledShader.firstOffsetInfo.usesInstanceIndex = data->has_instance_index; 783 if (compiledShader.firstOffsetInfo.usesInstanceIndex) { 784 compiledShader.firstOffsetInfo.instanceIndexOffset = 785 data->first_instance_offset; 786 } 787 } 788 789 program = &programAsValue; 790 } else { 791 program = GetTintProgram(); 792 } 793 794 ShaderCompilationRequest request; 795 DAWN_TRY_ASSIGN( 796 request, ShaderCompilationRequest::Create( 797 programmableStage.entryPoint.c_str(), stage, layout, compileFlags, device, 798 program, GetEntryPoint(programmableStage.entryPoint), programmableStage)); 799 800 PersistentCacheKey shaderCacheKey; 801 DAWN_TRY_ASSIGN(shaderCacheKey, request.CreateCacheKey()); 802 803 DAWN_TRY_ASSIGN( 804 compiledShader.cachedShader, 805 device->GetPersistentCache()->GetOrCreate( 806 shaderCacheKey, [&](auto doCache) -> MaybeError { 807 DAWN_TRY(CompileShader( 808 device->GetFunctions(), 809 device->IsToggleEnabled(Toggle::UseDXC) ? device->GetDxcLibrary().Get() 810 : nullptr, 811 device->IsToggleEnabled(Toggle::UseDXC) ? device->GetDxcCompiler().Get() 812 : nullptr, 813 std::move(request), device->IsToggleEnabled(Toggle::DumpShaders), 814 [&](WGPULoggingType loggingType, const char* message) { 815 GetDevice()->EmitLog(loggingType, message); 816 }, 817 &compiledShader)); 818 const D3D12_SHADER_BYTECODE shader = compiledShader.GetD3D12ShaderBytecode(); 819 doCache(shader.pShaderBytecode, shader.BytecodeLength); 820 return {}; 821 })); 822 823 return std::move(compiledShader); 824 } 825 GetD3D12ShaderBytecode() const826 D3D12_SHADER_BYTECODE CompiledShader::GetD3D12ShaderBytecode() const { 827 if (cachedShader.buffer != nullptr) { 828 return {cachedShader.buffer.get(), cachedShader.bufferSize}; 829 } else if (compiledFXCShader != nullptr) { 830 return {compiledFXCShader->GetBufferPointer(), compiledFXCShader->GetBufferSize()}; 831 } else if (compiledDXCShader != nullptr) { 832 return {compiledDXCShader->GetBufferPointer(), compiledDXCShader->GetBufferSize()}; 833 } 834 UNREACHABLE(); 835 return {}; 836 } 837 }} // namespace dawn_native::d3d12 838