• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 The Dawn Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "tests/DawnTest.h"
16 
17 #include "utils/ComboRenderPipelineDescriptor.h"
18 #include "utils/WGPUHelpers.h"
19 
// EXPECT_CACHE_HIT(N, statement): runs `statement`, then calls FlushWire() (presumably so work
// issued through the wire client actually reaches the device), and finally expects the fixture's
// mPersistentCache.mHitCount to have increased by exactly N.
#define EXPECT_CACHE_HIT(N, statement)                                         \
    do {                                                                       \
        const size_t cacheHitsBefore_ = mPersistentCache.mHitCount;            \
        statement;                                                             \
        FlushWire();                                                           \
        const size_t cacheHitDelta_ = mPersistentCache.mHitCount - cacheHitsBefore_; \
        EXPECT_EQ(N, cacheHitDelta_);                                          \
    } while (0)
28 
29 // FakePersistentCache implements a in-memory persistent cache.
30 class FakePersistentCache : public dawn_platform::CachingInterface {
31   public:
32     // PersistentCache API
StoreData(const WGPUDevice device,const void * key,size_t keySize,const void * value,size_t valueSize)33     void StoreData(const WGPUDevice device,
34                    const void* key,
35                    size_t keySize,
36                    const void* value,
37                    size_t valueSize) override {
38         if (mIsDisabled)
39             return;
40         const std::string keyStr(reinterpret_cast<const char*>(key), keySize);
41 
42         const uint8_t* value_start = reinterpret_cast<const uint8_t*>(value);
43         std::vector<uint8_t> entry_value(value_start, value_start + valueSize);
44 
45         EXPECT_TRUE(mCache.insert({keyStr, std::move(entry_value)}).second);
46     }
47 
LoadData(const WGPUDevice device,const void * key,size_t keySize,void * value,size_t valueSize)48     size_t LoadData(const WGPUDevice device,
49                     const void* key,
50                     size_t keySize,
51                     void* value,
52                     size_t valueSize) override {
53         const std::string keyStr(reinterpret_cast<const char*>(key), keySize);
54         auto entry = mCache.find(keyStr);
55         if (entry == mCache.end()) {
56             return 0;
57         }
58         if (valueSize >= entry->second.size()) {
59             memcpy(value, entry->second.data(), entry->second.size());
60         }
61         mHitCount++;
62         return entry->second.size();
63     }
64 
65     using Blob = std::vector<uint8_t>;
66     using FakeCache = std::unordered_map<std::string, Blob>;
67 
68     FakeCache mCache;
69 
70     size_t mHitCount = 0;
71     bool mIsDisabled = false;
72 };
73 
74 // Test platform that only supports caching.
75 class DawnTestPlatform : public dawn_platform::Platform {
76   public:
DawnTestPlatform(dawn_platform::CachingInterface * cachingInterface)77     DawnTestPlatform(dawn_platform::CachingInterface* cachingInterface)
78         : mCachingInterface(cachingInterface) {
79     }
80     ~DawnTestPlatform() override = default;
81 
GetCachingInterface(const void * fingerprint,size_t fingerprintSize)82     dawn_platform::CachingInterface* GetCachingInterface(const void* fingerprint,
83                                                          size_t fingerprintSize) override {
84         return mCachingInterface;
85     }
86 
87     dawn_platform::CachingInterface* mCachingInterface = nullptr;
88 };
89 
90 class D3D12CachingTests : public DawnTest {
91   protected:
CreateTestPlatform()92     std::unique_ptr<dawn_platform::Platform> CreateTestPlatform() override {
93         return std::make_unique<DawnTestPlatform>(&mPersistentCache);
94     }
95 
96     FakePersistentCache mPersistentCache;
97 };
98 
99 // Test that duplicate WGSL still re-compiles HLSL even when the cache is not enabled.
TEST_P(D3D12CachingTests,SameShaderNoCache)100 TEST_P(D3D12CachingTests, SameShaderNoCache) {
101     mPersistentCache.mIsDisabled = true;
102 
103     wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
104         [[stage(vertex)]] fn vertex_main() -> [[builtin(position)]] vec4<f32> {
105             return vec4<f32>(0.0, 0.0, 0.0, 1.0);
106         }
107 
108         [[stage(fragment)]] fn fragment_main() -> [[location(0)]] vec4<f32> {
109           return vec4<f32>(1.0, 0.0, 0.0, 1.0);
110         }
111     )");
112 
113     // Store the WGSL shader into the cache.
114     {
115         utils::ComboRenderPipelineDescriptor desc;
116         desc.vertex.module = module;
117         desc.vertex.entryPoint = "vertex_main";
118         desc.cFragment.module = module;
119         desc.cFragment.entryPoint = "fragment_main";
120 
121         EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc));
122     }
123 
124     EXPECT_EQ(mPersistentCache.mCache.size(), 0u);
125 
126     // Load the same WGSL shader from the cache.
127     {
128         utils::ComboRenderPipelineDescriptor desc;
129         desc.vertex.module = module;
130         desc.vertex.entryPoint = "vertex_main";
131         desc.cFragment.module = module;
132         desc.cFragment.entryPoint = "fragment_main";
133 
134         EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc));
135     }
136 
137     EXPECT_EQ(mPersistentCache.mCache.size(), 0u);
138 }
139 
140 // Test creating a pipeline from two entrypoints in multiple stages will cache the correct number
141 // of HLSL shaders. WGSL shader should result into caching 2 HLSL shaders (stage x
142 // entrypoints)
TEST_P(D3D12CachingTests,ReuseShaderWithMultipleEntryPointsPerStage)143 TEST_P(D3D12CachingTests, ReuseShaderWithMultipleEntryPointsPerStage) {
144     wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
145         [[stage(vertex)]] fn vertex_main() -> [[builtin(position)]] vec4<f32> {
146             return vec4<f32>(0.0, 0.0, 0.0, 1.0);
147         }
148 
149         [[stage(fragment)]] fn fragment_main() -> [[location(0)]] vec4<f32> {
150           return vec4<f32>(1.0, 0.0, 0.0, 1.0);
151         }
152     )");
153 
154     // Store the WGSL shader into the cache.
155     {
156         utils::ComboRenderPipelineDescriptor desc;
157         desc.vertex.module = module;
158         desc.vertex.entryPoint = "vertex_main";
159         desc.cFragment.module = module;
160         desc.cFragment.entryPoint = "fragment_main";
161 
162         EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc));
163     }
164 
165     EXPECT_EQ(mPersistentCache.mCache.size(), 2u);
166 
167     // Load the same WGSL shader from the cache.
168     {
169         utils::ComboRenderPipelineDescriptor desc;
170         desc.vertex.module = module;
171         desc.vertex.entryPoint = "vertex_main";
172         desc.cFragment.module = module;
173         desc.cFragment.entryPoint = "fragment_main";
174 
175         // Cached HLSL shader calls LoadData twice (once to peek, again to get), so check 2 x
176         // kNumOfShaders hits.
177         EXPECT_CACHE_HIT(4u, device.CreateRenderPipeline(&desc));
178     }
179 
180     EXPECT_EQ(mPersistentCache.mCache.size(), 2u);
181 
182     // Modify the WGSL shader functions and make sure it doesn't hit.
183     wgpu::ShaderModule newModule = utils::CreateShaderModule(device, R"(
184       [[stage(vertex)]] fn vertex_main() -> [[builtin(position)]] vec4<f32> {
185           return vec4<f32>(1.0, 1.0, 1.0, 1.0);
186       }
187 
188       [[stage(fragment)]] fn fragment_main() -> [[location(0)]] vec4<f32> {
189         return vec4<f32>(1.0, 1.0, 1.0, 1.0);
190       }
191   )");
192 
193     {
194         utils::ComboRenderPipelineDescriptor desc;
195         desc.vertex.module = newModule;
196         desc.vertex.entryPoint = "vertex_main";
197         desc.cFragment.module = newModule;
198         desc.cFragment.entryPoint = "fragment_main";
199         EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc));
200     }
201 
202     // Cached HLSL shader calls LoadData twice (once to peek, again to get), so check 2 x
203     // kNumOfShaders hits.
204     EXPECT_EQ(mPersistentCache.mCache.size(), 4u);
205 }
206 
207 // Test creating a WGSL shader with two entrypoints in the same stage will cache the correct number
208 // of HLSL shaders. WGSL shader should result into caching 1 HLSL shader (stage x entrypoints)
TEST_P(D3D12CachingTests,ReuseShaderWithMultipleEntryPoints)209 TEST_P(D3D12CachingTests, ReuseShaderWithMultipleEntryPoints) {
210     wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
211         [[block]] struct Data {
212             data : u32;
213         };
214         [[binding(0), group(0)]] var<storage, read_write> data : Data;
215 
216         [[stage(compute), workgroup_size(1)]] fn write1() {
217             data.data = 1u;
218         }
219 
220         [[stage(compute), workgroup_size(1)]] fn write42() {
221             data.data = 42u;
222         }
223     )");
224 
225     // Store the WGSL shader into the cache.
226     {
227         wgpu::ComputePipelineDescriptor desc;
228         desc.compute.module = module;
229         desc.compute.entryPoint = "write1";
230         EXPECT_CACHE_HIT(0u, device.CreateComputePipeline(&desc));
231 
232         desc.compute.module = module;
233         desc.compute.entryPoint = "write42";
234         EXPECT_CACHE_HIT(0u, device.CreateComputePipeline(&desc));
235     }
236 
237     EXPECT_EQ(mPersistentCache.mCache.size(), 2u);
238 
239     // Load the same WGSL shader from the cache.
240     {
241         wgpu::ComputePipelineDescriptor desc;
242         desc.compute.module = module;
243         desc.compute.entryPoint = "write1";
244 
245         // Cached HLSL shader calls LoadData twice (once to peek, again to get), so check 2 x
246         // kNumOfShaders hits.
247         EXPECT_CACHE_HIT(2u, device.CreateComputePipeline(&desc));
248 
249         desc.compute.module = module;
250         desc.compute.entryPoint = "write42";
251 
252         // Cached HLSL shader calls LoadData twice, so check 2 x kNumOfShaders hits.
253         EXPECT_CACHE_HIT(2u, device.CreateComputePipeline(&desc));
254     }
255 
256     EXPECT_EQ(mPersistentCache.mCache.size(), 2u);
257 }
258 
// These tests count cached HLSL shaders, so the suite is only instantiated for the D3D12 backend.
DAWN_INSTANTIATE_TEST(D3D12CachingTests, D3D12Backend());
260