1 // Copyright 2020 The Dawn Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "tests/DawnTest.h"
16
17 #include "utils/ComboRenderPipelineDescriptor.h"
18 #include "utils/WGPUHelpers.h"
19
// Runs |statement|, flushes the wire, and expects that the persistent cache's hit counter
// (mPersistentCache.mHitCount) advanced by exactly |N| while doing so.
//
// Must be expanded where |mPersistentCache| and FlushWire() are in scope. The macro-local
// variables carry a distinctive name so that identifiers called |before| or |after| used inside
// |statement| refer to the caller's variables instead of being captured by the macro.
#define EXPECT_CACHE_HIT(N, statement)                                        \
    do {                                                                      \
        const size_t expectCacheHitBefore_ = mPersistentCache.mHitCount;      \
        statement;                                                            \
        FlushWire();                                                          \
        const size_t expectCacheHitAfter_ = mPersistentCache.mHitCount;       \
        EXPECT_EQ((N), expectCacheHitAfter_ - expectCacheHitBefore_);         \
    } while (0)
28
29 // FakePersistentCache implements a in-memory persistent cache.
30 class FakePersistentCache : public dawn_platform::CachingInterface {
31 public:
32 // PersistentCache API
StoreData(const WGPUDevice device,const void * key,size_t keySize,const void * value,size_t valueSize)33 void StoreData(const WGPUDevice device,
34 const void* key,
35 size_t keySize,
36 const void* value,
37 size_t valueSize) override {
38 if (mIsDisabled)
39 return;
40 const std::string keyStr(reinterpret_cast<const char*>(key), keySize);
41
42 const uint8_t* value_start = reinterpret_cast<const uint8_t*>(value);
43 std::vector<uint8_t> entry_value(value_start, value_start + valueSize);
44
45 EXPECT_TRUE(mCache.insert({keyStr, std::move(entry_value)}).second);
46 }
47
LoadData(const WGPUDevice device,const void * key,size_t keySize,void * value,size_t valueSize)48 size_t LoadData(const WGPUDevice device,
49 const void* key,
50 size_t keySize,
51 void* value,
52 size_t valueSize) override {
53 const std::string keyStr(reinterpret_cast<const char*>(key), keySize);
54 auto entry = mCache.find(keyStr);
55 if (entry == mCache.end()) {
56 return 0;
57 }
58 if (valueSize >= entry->second.size()) {
59 memcpy(value, entry->second.data(), entry->second.size());
60 }
61 mHitCount++;
62 return entry->second.size();
63 }
64
65 using Blob = std::vector<uint8_t>;
66 using FakeCache = std::unordered_map<std::string, Blob>;
67
68 FakeCache mCache;
69
70 size_t mHitCount = 0;
71 bool mIsDisabled = false;
72 };
73
74 // Test platform that only supports caching.
75 class DawnTestPlatform : public dawn_platform::Platform {
76 public:
DawnTestPlatform(dawn_platform::CachingInterface * cachingInterface)77 DawnTestPlatform(dawn_platform::CachingInterface* cachingInterface)
78 : mCachingInterface(cachingInterface) {
79 }
80 ~DawnTestPlatform() override = default;
81
GetCachingInterface(const void * fingerprint,size_t fingerprintSize)82 dawn_platform::CachingInterface* GetCachingInterface(const void* fingerprint,
83 size_t fingerprintSize) override {
84 return mCachingInterface;
85 }
86
87 dawn_platform::CachingInterface* mCachingInterface = nullptr;
88 };
89
90 class D3D12CachingTests : public DawnTest {
91 protected:
CreateTestPlatform()92 std::unique_ptr<dawn_platform::Platform> CreateTestPlatform() override {
93 return std::make_unique<DawnTestPlatform>(&mPersistentCache);
94 }
95
96 FakePersistentCache mPersistentCache;
97 };
98
99 // Test that duplicate WGSL still re-compiles HLSL even when the cache is not enabled.
TEST_P(D3D12CachingTests,SameShaderNoCache)100 TEST_P(D3D12CachingTests, SameShaderNoCache) {
101 mPersistentCache.mIsDisabled = true;
102
103 wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
104 [[stage(vertex)]] fn vertex_main() -> [[builtin(position)]] vec4<f32> {
105 return vec4<f32>(0.0, 0.0, 0.0, 1.0);
106 }
107
108 [[stage(fragment)]] fn fragment_main() -> [[location(0)]] vec4<f32> {
109 return vec4<f32>(1.0, 0.0, 0.0, 1.0);
110 }
111 )");
112
113 // Store the WGSL shader into the cache.
114 {
115 utils::ComboRenderPipelineDescriptor desc;
116 desc.vertex.module = module;
117 desc.vertex.entryPoint = "vertex_main";
118 desc.cFragment.module = module;
119 desc.cFragment.entryPoint = "fragment_main";
120
121 EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc));
122 }
123
124 EXPECT_EQ(mPersistentCache.mCache.size(), 0u);
125
126 // Load the same WGSL shader from the cache.
127 {
128 utils::ComboRenderPipelineDescriptor desc;
129 desc.vertex.module = module;
130 desc.vertex.entryPoint = "vertex_main";
131 desc.cFragment.module = module;
132 desc.cFragment.entryPoint = "fragment_main";
133
134 EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc));
135 }
136
137 EXPECT_EQ(mPersistentCache.mCache.size(), 0u);
138 }
139
140 // Test creating a pipeline from two entrypoints in multiple stages will cache the correct number
141 // of HLSL shaders. WGSL shader should result into caching 2 HLSL shaders (stage x
142 // entrypoints)
TEST_P(D3D12CachingTests,ReuseShaderWithMultipleEntryPointsPerStage)143 TEST_P(D3D12CachingTests, ReuseShaderWithMultipleEntryPointsPerStage) {
144 wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
145 [[stage(vertex)]] fn vertex_main() -> [[builtin(position)]] vec4<f32> {
146 return vec4<f32>(0.0, 0.0, 0.0, 1.0);
147 }
148
149 [[stage(fragment)]] fn fragment_main() -> [[location(0)]] vec4<f32> {
150 return vec4<f32>(1.0, 0.0, 0.0, 1.0);
151 }
152 )");
153
154 // Store the WGSL shader into the cache.
155 {
156 utils::ComboRenderPipelineDescriptor desc;
157 desc.vertex.module = module;
158 desc.vertex.entryPoint = "vertex_main";
159 desc.cFragment.module = module;
160 desc.cFragment.entryPoint = "fragment_main";
161
162 EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc));
163 }
164
165 EXPECT_EQ(mPersistentCache.mCache.size(), 2u);
166
167 // Load the same WGSL shader from the cache.
168 {
169 utils::ComboRenderPipelineDescriptor desc;
170 desc.vertex.module = module;
171 desc.vertex.entryPoint = "vertex_main";
172 desc.cFragment.module = module;
173 desc.cFragment.entryPoint = "fragment_main";
174
175 // Cached HLSL shader calls LoadData twice (once to peek, again to get), so check 2 x
176 // kNumOfShaders hits.
177 EXPECT_CACHE_HIT(4u, device.CreateRenderPipeline(&desc));
178 }
179
180 EXPECT_EQ(mPersistentCache.mCache.size(), 2u);
181
182 // Modify the WGSL shader functions and make sure it doesn't hit.
183 wgpu::ShaderModule newModule = utils::CreateShaderModule(device, R"(
184 [[stage(vertex)]] fn vertex_main() -> [[builtin(position)]] vec4<f32> {
185 return vec4<f32>(1.0, 1.0, 1.0, 1.0);
186 }
187
188 [[stage(fragment)]] fn fragment_main() -> [[location(0)]] vec4<f32> {
189 return vec4<f32>(1.0, 1.0, 1.0, 1.0);
190 }
191 )");
192
193 {
194 utils::ComboRenderPipelineDescriptor desc;
195 desc.vertex.module = newModule;
196 desc.vertex.entryPoint = "vertex_main";
197 desc.cFragment.module = newModule;
198 desc.cFragment.entryPoint = "fragment_main";
199 EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc));
200 }
201
202 // Cached HLSL shader calls LoadData twice (once to peek, again to get), so check 2 x
203 // kNumOfShaders hits.
204 EXPECT_EQ(mPersistentCache.mCache.size(), 4u);
205 }
206
207 // Test creating a WGSL shader with two entrypoints in the same stage will cache the correct number
208 // of HLSL shaders. WGSL shader should result into caching 1 HLSL shader (stage x entrypoints)
TEST_P(D3D12CachingTests,ReuseShaderWithMultipleEntryPoints)209 TEST_P(D3D12CachingTests, ReuseShaderWithMultipleEntryPoints) {
210 wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
211 [[block]] struct Data {
212 data : u32;
213 };
214 [[binding(0), group(0)]] var<storage, read_write> data : Data;
215
216 [[stage(compute), workgroup_size(1)]] fn write1() {
217 data.data = 1u;
218 }
219
220 [[stage(compute), workgroup_size(1)]] fn write42() {
221 data.data = 42u;
222 }
223 )");
224
225 // Store the WGSL shader into the cache.
226 {
227 wgpu::ComputePipelineDescriptor desc;
228 desc.compute.module = module;
229 desc.compute.entryPoint = "write1";
230 EXPECT_CACHE_HIT(0u, device.CreateComputePipeline(&desc));
231
232 desc.compute.module = module;
233 desc.compute.entryPoint = "write42";
234 EXPECT_CACHE_HIT(0u, device.CreateComputePipeline(&desc));
235 }
236
237 EXPECT_EQ(mPersistentCache.mCache.size(), 2u);
238
239 // Load the same WGSL shader from the cache.
240 {
241 wgpu::ComputePipelineDescriptor desc;
242 desc.compute.module = module;
243 desc.compute.entryPoint = "write1";
244
245 // Cached HLSL shader calls LoadData twice (once to peek, again to get), so check 2 x
246 // kNumOfShaders hits.
247 EXPECT_CACHE_HIT(2u, device.CreateComputePipeline(&desc));
248
249 desc.compute.module = module;
250 desc.compute.entryPoint = "write42";
251
252 // Cached HLSL shader calls LoadData twice, so check 2 x kNumOfShaders hits.
253 EXPECT_CACHE_HIT(2u, device.CreateComputePipeline(&desc));
254 }
255
256 EXPECT_EQ(mPersistentCache.mCache.size(), 2u);
257 }
258
259 DAWN_INSTANTIATE_TEST(D3D12CachingTests, D3D12Backend());
260