• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2018 The Dawn Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "dawn_native/vulkan/ShaderModuleVk.h"
16 
17 #include "dawn_native/SpirvValidation.h"
18 #include "dawn_native/TintUtils.h"
19 #include "dawn_native/vulkan/BindGroupLayoutVk.h"
20 #include "dawn_native/vulkan/DeviceVk.h"
21 #include "dawn_native/vulkan/FencedDeleter.h"
22 #include "dawn_native/vulkan/PipelineLayoutVk.h"
23 #include "dawn_native/vulkan/UtilsVulkan.h"
24 #include "dawn_native/vulkan/VulkanError.h"
25 
26 #include <tint/tint.h>
27 #include <spirv-tools/libspirv.hpp>
28 
29 namespace dawn_native { namespace vulkan {
30 
ConcurrentTransformedShaderModuleCache(Device * device)31     ShaderModule::ConcurrentTransformedShaderModuleCache::ConcurrentTransformedShaderModuleCache(
32         Device* device)
33         : mDevice(device) {
34     }
35 
36     ShaderModule::ConcurrentTransformedShaderModuleCache::
~ConcurrentTransformedShaderModuleCache()37         ~ConcurrentTransformedShaderModuleCache() {
38         std::lock_guard<std::mutex> lock(mMutex);
39         for (const auto& iter : mTransformedShaderModuleCache) {
40             mDevice->GetFencedDeleter()->DeleteWhenUnused(iter.second);
41         }
42     }
43 
FindShaderModule(const PipelineLayoutEntryPointPair & key)44     VkShaderModule ShaderModule::ConcurrentTransformedShaderModuleCache::FindShaderModule(
45         const PipelineLayoutEntryPointPair& key) {
46         std::lock_guard<std::mutex> lock(mMutex);
47         auto iter = mTransformedShaderModuleCache.find(key);
48         if (iter != mTransformedShaderModuleCache.end()) {
49             auto cached = iter->second;
50             return cached;
51         }
52         return VK_NULL_HANDLE;
53     }
54 
AddOrGetCachedShaderModule(const PipelineLayoutEntryPointPair & key,VkShaderModule value)55     VkShaderModule ShaderModule::ConcurrentTransformedShaderModuleCache::AddOrGetCachedShaderModule(
56         const PipelineLayoutEntryPointPair& key,
57         VkShaderModule value) {
58         ASSERT(value != VK_NULL_HANDLE);
59         std::lock_guard<std::mutex> lock(mMutex);
60         auto iter = mTransformedShaderModuleCache.find(key);
61         if (iter == mTransformedShaderModuleCache.end()) {
62             mTransformedShaderModuleCache.emplace(key, value);
63             return value;
64         } else {
65             mDevice->GetFencedDeleter()->DeleteWhenUnused(value);
66             return iter->second;
67         }
68     }
69 
70     // static
Create(Device * device,const ShaderModuleDescriptor * descriptor,ShaderModuleParseResult * parseResult)71     ResultOrError<Ref<ShaderModule>> ShaderModule::Create(Device* device,
72                                                           const ShaderModuleDescriptor* descriptor,
73                                                           ShaderModuleParseResult* parseResult) {
74         Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
75         DAWN_TRY(module->Initialize(parseResult));
76         return module;
77     }
78 
ShaderModule(Device * device,const ShaderModuleDescriptor * descriptor)79     ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor)
80         : ShaderModuleBase(device, descriptor),
81           mTransformedShaderModuleCache(
82               std::make_unique<ConcurrentTransformedShaderModuleCache>(device)) {
83     }
84 
Initialize(ShaderModuleParseResult * parseResult)85     MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
86         if (GetDevice()->IsRobustnessEnabled()) {
87             ScopedTintICEHandler scopedICEHandler(GetDevice());
88 
89             tint::transform::Robustness robustness;
90             tint::transform::DataMap transformInputs;
91 
92             tint::Program program;
93             DAWN_TRY_ASSIGN(program, RunTransforms(&robustness, parseResult->tintProgram.get(),
94                                                    transformInputs, nullptr, nullptr));
95             // Rather than use a new ParseResult object, we just reuse the original parseResult
96             parseResult->tintProgram = std::make_unique<tint::Program>(std::move(program));
97         }
98 
99         return InitializeBase(parseResult);
100     }
101 
DestroyImpl()102     void ShaderModule::DestroyImpl() {
103         ShaderModuleBase::DestroyImpl();
104         // Remove reference to internal cache to trigger cleanup.
105         mTransformedShaderModuleCache = nullptr;
106     }
107 
108     ShaderModule::~ShaderModule() = default;
109 
GetTransformedModuleHandle(const char * entryPointName,PipelineLayout * layout)110     ResultOrError<VkShaderModule> ShaderModule::GetTransformedModuleHandle(
111         const char* entryPointName,
112         PipelineLayout* layout) {
113         // If the shader was destroyed, we should never call this function.
114         ASSERT(IsAlive());
115 
116         ScopedTintICEHandler scopedICEHandler(GetDevice());
117 
118         auto cacheKey = std::make_pair(layout, entryPointName);
119         VkShaderModule cachedShaderModule =
120             mTransformedShaderModuleCache->FindShaderModule(cacheKey);
121         if (cachedShaderModule != VK_NULL_HANDLE) {
122             return cachedShaderModule;
123         }
124 
125         // Creation of VkShaderModule is deferred to this point when using tint generator
126 
127         // Remap BindingNumber to BindingIndex in WGSL shader
128         using BindingRemapper = tint::transform::BindingRemapper;
129         using BindingPoint = tint::transform::BindingPoint;
130         BindingRemapper::BindingPoints bindingPoints;
131         BindingRemapper::AccessControls accessControls;
132 
133         const BindingInfoArray& moduleBindingInfo = GetEntryPoint(entryPointName).bindings;
134 
135         for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
136             const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
137             const auto& groupBindingInfo = moduleBindingInfo[group];
138             for (const auto& it : groupBindingInfo) {
139                 BindingNumber binding = it.first;
140                 BindingIndex bindingIndex = bgl->GetBindingIndex(binding);
141                 BindingPoint srcBindingPoint{static_cast<uint32_t>(group),
142                                              static_cast<uint32_t>(binding)};
143 
144                 BindingPoint dstBindingPoint{static_cast<uint32_t>(group),
145                                              static_cast<uint32_t>(bindingIndex)};
146                 if (srcBindingPoint != dstBindingPoint) {
147                     bindingPoints.emplace(srcBindingPoint, dstBindingPoint);
148                 }
149             }
150         }
151 
152         tint::transform::Manager transformManager;
153         transformManager.append(std::make_unique<tint::transform::BindingRemapper>());
154         // Many Vulkan drivers can't handle multi-entrypoint shader modules.
155         transformManager.append(std::make_unique<tint::transform::SingleEntryPoint>());
156 
157         tint::transform::DataMap transformInputs;
158         transformInputs.Add<BindingRemapper::Remappings>(std::move(bindingPoints),
159                                                          std::move(accessControls),
160                                                          /* mayCollide */ false);
161         transformInputs.Add<tint::transform::SingleEntryPoint::Config>(entryPointName);
162 
163         tint::Program program;
164         DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, GetTintProgram(), transformInputs,
165                                                nullptr, nullptr));
166 
167         tint::writer::spirv::Options options;
168         options.emit_vertex_point_size = true;
169         options.disable_workgroup_init = GetDevice()->IsToggleEnabled(Toggle::DisableWorkgroupInit);
170         auto result = tint::writer::spirv::Generate(&program, options);
171         DAWN_INVALID_IF(!result.success, "An error occured while generating SPIR-V: %s.",
172                         result.error);
173 
174         std::vector<uint32_t> spirv = std::move(result.spirv);
175         DAWN_TRY(
176             ValidateSpirv(GetDevice(), spirv, GetDevice()->IsToggleEnabled(Toggle::DumpShaders)));
177 
178         VkShaderModuleCreateInfo createInfo;
179         createInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
180         createInfo.pNext = nullptr;
181         createInfo.flags = 0;
182         createInfo.codeSize = spirv.size() * sizeof(uint32_t);
183         createInfo.pCode = spirv.data();
184 
185         Device* device = ToBackend(GetDevice());
186 
187         VkShaderModule newHandle = VK_NULL_HANDLE;
188 
189         DAWN_TRY(CheckVkSuccess(
190             device->fn.CreateShaderModule(device->GetVkDevice(), &createInfo, nullptr, &*newHandle),
191             "CreateShaderModule"));
192         if (newHandle != VK_NULL_HANDLE) {
193             newHandle =
194                 mTransformedShaderModuleCache->AddOrGetCachedShaderModule(cacheKey, newHandle);
195         }
196 
197         SetDebugName(ToBackend(GetDevice()), VK_OBJECT_TYPE_SHADER_MODULE,
198                      reinterpret_cast<uint64_t&>(newHandle), "Dawn_ShaderModule", GetLabel());
199 
200         return newHandle;
201     }
202 
203 }}  // namespace dawn_native::vulkan
204