• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Dawn Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "tests/DawnNativeTest.h"
16 
17 #include "dawn_native/CommandBuffer.h"
18 #include "dawn_native/Commands.h"
19 #include "dawn_native/ComputePassEncoder.h"
20 #include "utils/WGPUHelpers.h"
21 
22 class CommandBufferEncodingTests : public DawnNativeTest {
23   protected:
ExpectCommands(dawn_native::CommandIterator * commands,std::vector<std::pair<dawn_native::Command,std::function<void (dawn_native::CommandIterator *)>>> expectedCommands)24     void ExpectCommands(dawn_native::CommandIterator* commands,
25                         std::vector<std::pair<dawn_native::Command,
26                                               std::function<void(dawn_native::CommandIterator*)>>>
27                             expectedCommands) {
28         dawn_native::Command commandId;
29         for (uint32_t commandIndex = 0; commands->NextCommandId(&commandId); ++commandIndex) {
30             ASSERT_LT(commandIndex, expectedCommands.size()) << "Unexpected command";
31             ASSERT_EQ(commandId, expectedCommands[commandIndex].first)
32                 << "at command " << commandIndex;
33             expectedCommands[commandIndex].second(commands);
34         }
35     }
36 };
37 
38 // Indirect dispatch validation changes the bind groups in the middle
39 // of a pass. Test that bindings are restored after the validation runs.
TEST_F(CommandBufferEncodingTests,ComputePassEncoderIndirectDispatchStateRestoration)40 TEST_F(CommandBufferEncodingTests, ComputePassEncoderIndirectDispatchStateRestoration) {
41     using namespace dawn_native;
42 
43     wgpu::BindGroupLayout staticLayout =
44         utils::MakeBindGroupLayout(device, {{
45                                                0,
46                                                wgpu::ShaderStage::Compute,
47                                                wgpu::BufferBindingType::Uniform,
48                                            }});
49 
50     wgpu::BindGroupLayout dynamicLayout =
51         utils::MakeBindGroupLayout(device, {{
52                                                0,
53                                                wgpu::ShaderStage::Compute,
54                                                wgpu::BufferBindingType::Uniform,
55                                                true,
56                                            }});
57 
58     // Create a simple pipeline
59     wgpu::ComputePipelineDescriptor csDesc;
60     csDesc.compute.module = utils::CreateShaderModule(device, R"(
61         [[stage(compute), workgroup_size(1, 1, 1)]]
62         fn main() {
63         })");
64     csDesc.compute.entryPoint = "main";
65 
66     wgpu::PipelineLayout pl0 = utils::MakePipelineLayout(device, {staticLayout, dynamicLayout});
67     csDesc.layout = pl0;
68     wgpu::ComputePipeline pipeline0 = device.CreateComputePipeline(&csDesc);
69 
70     wgpu::PipelineLayout pl1 = utils::MakePipelineLayout(device, {dynamicLayout, staticLayout});
71     csDesc.layout = pl1;
72     wgpu::ComputePipeline pipeline1 = device.CreateComputePipeline(&csDesc);
73 
74     // Create buffers to use for both the indirect buffer and the bind groups.
75     wgpu::Buffer indirectBuffer =
76         utils::CreateBufferFromData<uint32_t>(device, wgpu::BufferUsage::Indirect, {1, 2, 3, 4});
77 
78     wgpu::BufferDescriptor uniformBufferDesc = {};
79     uniformBufferDesc.size = 512;
80     uniformBufferDesc.usage = wgpu::BufferUsage::Uniform;
81     wgpu::Buffer uniformBuffer = device.CreateBuffer(&uniformBufferDesc);
82 
83     wgpu::BindGroup staticBG = utils::MakeBindGroup(device, staticLayout, {{0, uniformBuffer}});
84 
85     wgpu::BindGroup dynamicBG =
86         utils::MakeBindGroup(device, dynamicLayout, {{0, uniformBuffer, 0, 256}});
87 
88     uint32_t dynamicOffset = 256;
89     std::vector<uint32_t> emptyDynamicOffsets = {};
90     std::vector<uint32_t> singleDynamicOffset = {dynamicOffset};
91 
92     // Begin encoding commands.
93     wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
94     wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
95 
96     CommandBufferStateTracker* stateTracker =
97         FromAPI(pass.Get())->GetCommandBufferStateTrackerForTesting();
98 
99     // Perform a dispatch indirect which will be preceded by a validation dispatch.
100     pass.SetPipeline(pipeline0);
101     pass.SetBindGroup(0, staticBG);
102     pass.SetBindGroup(1, dynamicBG, 1, &dynamicOffset);
103     EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline0.Get());
104 
105     pass.DispatchIndirect(indirectBuffer, 0);
106 
107     // Expect restored state.
108     EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline0.Get());
109     EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl0.Get());
110     EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(0))), staticBG.Get());
111     EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(0)), emptyDynamicOffsets);
112     EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(1))), dynamicBG.Get());
113     EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(1)), singleDynamicOffset);
114 
115     // Dispatch again to check that the restored state can be used.
116     // Also pass an indirect offset which should get replaced with the offset
117     // into the scratch indirect buffer (0).
118     pass.DispatchIndirect(indirectBuffer, 4);
119 
120     // Expect restored state.
121     EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline0.Get());
122     EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl0.Get());
123     EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(0))), staticBG.Get());
124     EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(0)), emptyDynamicOffsets);
125     EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(1))), dynamicBG.Get());
126     EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(1)), singleDynamicOffset);
127 
128     // Change the pipeline
129     pass.SetPipeline(pipeline1);
130     pass.SetBindGroup(0, dynamicBG, 1, &dynamicOffset);
131     pass.SetBindGroup(1, staticBG);
132     EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline1.Get());
133     EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl1.Get());
134 
135     pass.DispatchIndirect(indirectBuffer, 0);
136 
137     // Expect restored state.
138     EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline1.Get());
139     EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl1.Get());
140     EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(0))), dynamicBG.Get());
141     EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(0)), singleDynamicOffset);
142     EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(1))), staticBG.Get());
143     EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(1)), emptyDynamicOffsets);
144 
145     pass.EndPass();
146 
147     wgpu::CommandBuffer commandBuffer = encoder.Finish();
148 
149     auto ExpectSetPipeline = [](wgpu::ComputePipeline pipeline) {
150         return [pipeline](CommandIterator* commands) {
151             auto* cmd = commands->NextCommand<SetComputePipelineCmd>();
152             EXPECT_EQ(ToAPI(cmd->pipeline.Get()), pipeline.Get());
153         };
154     };
155 
156     auto ExpectSetBindGroup = [](uint32_t index, wgpu::BindGroup bg,
157                                  std::vector<uint32_t> offsets = {}) {
158         return [index, bg, offsets](CommandIterator* commands) {
159             auto* cmd = commands->NextCommand<SetBindGroupCmd>();
160             uint32_t* dynamicOffsets = nullptr;
161             if (cmd->dynamicOffsetCount > 0) {
162                 dynamicOffsets = commands->NextData<uint32_t>(cmd->dynamicOffsetCount);
163             }
164 
165             ASSERT_EQ(cmd->index, BindGroupIndex(index));
166             ASSERT_EQ(ToAPI(cmd->group.Get()), bg.Get());
167             ASSERT_EQ(cmd->dynamicOffsetCount, offsets.size());
168             for (uint32_t i = 0; i < cmd->dynamicOffsetCount; ++i) {
169                 ASSERT_EQ(dynamicOffsets[i], offsets[i]);
170             }
171         };
172     };
173 
174     // Initialize as null. Once we know the pointer, we'll check
175     // that it's the same buffer every time.
176     WGPUBuffer indirectScratchBuffer = nullptr;
177     auto ExpectDispatchIndirect = [&](CommandIterator* commands) {
178         auto* cmd = commands->NextCommand<DispatchIndirectCmd>();
179         if (indirectScratchBuffer == nullptr) {
180             indirectScratchBuffer = ToAPI(cmd->indirectBuffer.Get());
181         }
182         ASSERT_EQ(ToAPI(cmd->indirectBuffer.Get()), indirectScratchBuffer);
183         ASSERT_EQ(cmd->indirectOffset, uint64_t(0));
184     };
185 
186     // Initialize as null. Once we know the pointer, we'll check
187     // that it's the same pipeline every time.
188     WGPUComputePipeline validationPipeline = nullptr;
189     auto ExpectSetValidationPipeline = [&](CommandIterator* commands) {
190         auto* cmd = commands->NextCommand<SetComputePipelineCmd>();
191         WGPUComputePipeline pipeline = ToAPI(cmd->pipeline.Get());
192         if (validationPipeline != nullptr) {
193             EXPECT_EQ(pipeline, validationPipeline);
194         } else {
195             EXPECT_NE(pipeline, nullptr);
196             validationPipeline = pipeline;
197         }
198     };
199 
200     auto ExpectSetValidationBindGroup = [&](CommandIterator* commands) {
201         auto* cmd = commands->NextCommand<SetBindGroupCmd>();
202         ASSERT_EQ(cmd->index, BindGroupIndex(0));
203         ASSERT_NE(cmd->group.Get(), nullptr);
204         ASSERT_EQ(cmd->dynamicOffsetCount, 0u);
205     };
206 
207     auto ExpectSetValidationDispatch = [&](CommandIterator* commands) {
208         auto* cmd = commands->NextCommand<DispatchCmd>();
209         ASSERT_EQ(cmd->x, 1u);
210         ASSERT_EQ(cmd->y, 1u);
211         ASSERT_EQ(cmd->z, 1u);
212     };
213 
214     ExpectCommands(
215         FromAPI(commandBuffer.Get())->GetCommandIteratorForTesting(),
216         {
217             {Command::BeginComputePass,
218              [&](CommandIterator* commands) { SkipCommand(commands, Command::BeginComputePass); }},
219             // Expect the state to be set.
220             {Command::SetComputePipeline, ExpectSetPipeline(pipeline0)},
221             {Command::SetBindGroup, ExpectSetBindGroup(0, staticBG)},
222             {Command::SetBindGroup, ExpectSetBindGroup(1, dynamicBG, {dynamicOffset})},
223 
224             // Expect the validation.
225             {Command::SetComputePipeline, ExpectSetValidationPipeline},
226             {Command::SetBindGroup, ExpectSetValidationBindGroup},
227             {Command::Dispatch, ExpectSetValidationDispatch},
228 
229             // Expect the state to be restored.
230             {Command::SetComputePipeline, ExpectSetPipeline(pipeline0)},
231             {Command::SetBindGroup, ExpectSetBindGroup(0, staticBG)},
232             {Command::SetBindGroup, ExpectSetBindGroup(1, dynamicBG, {dynamicOffset})},
233 
234             // Expect the dispatchIndirect.
235             {Command::DispatchIndirect, ExpectDispatchIndirect},
236 
237             // Expect the validation.
238             {Command::SetComputePipeline, ExpectSetValidationPipeline},
239             {Command::SetBindGroup, ExpectSetValidationBindGroup},
240             {Command::Dispatch, ExpectSetValidationDispatch},
241 
242             // Expect the state to be restored.
243             {Command::SetComputePipeline, ExpectSetPipeline(pipeline0)},
244             {Command::SetBindGroup, ExpectSetBindGroup(0, staticBG)},
245             {Command::SetBindGroup, ExpectSetBindGroup(1, dynamicBG, {dynamicOffset})},
246 
247             // Expect the dispatchIndirect.
248             {Command::DispatchIndirect, ExpectDispatchIndirect},
249 
250             // Expect the state to be set (new pipeline).
251             {Command::SetComputePipeline, ExpectSetPipeline(pipeline1)},
252             {Command::SetBindGroup, ExpectSetBindGroup(0, dynamicBG, {dynamicOffset})},
253             {Command::SetBindGroup, ExpectSetBindGroup(1, staticBG)},
254 
255             // Expect the validation.
256             {Command::SetComputePipeline, ExpectSetValidationPipeline},
257             {Command::SetBindGroup, ExpectSetValidationBindGroup},
258             {Command::Dispatch, ExpectSetValidationDispatch},
259 
260             // Expect the state to be restored.
261             {Command::SetComputePipeline, ExpectSetPipeline(pipeline1)},
262             {Command::SetBindGroup, ExpectSetBindGroup(0, dynamicBG, {dynamicOffset})},
263             {Command::SetBindGroup, ExpectSetBindGroup(1, staticBG)},
264 
265             // Expect the dispatchIndirect.
266             {Command::DispatchIndirect, ExpectDispatchIndirect},
267 
268             {Command::EndComputePass,
269              [&](CommandIterator* commands) { commands->NextCommand<EndComputePassCmd>(); }},
270         });
271 }
272 
273 // Test that after restoring state, it is fully applied to the state tracker
274 // and does not leak state changes that occured between a snapshot and the
275 // state restoration.
TEST_F(CommandBufferEncodingTests,StateNotLeakedAfterRestore)276 TEST_F(CommandBufferEncodingTests, StateNotLeakedAfterRestore) {
277     using namespace dawn_native;
278 
279     wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
280     wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
281 
282     CommandBufferStateTracker* stateTracker =
283         FromAPI(pass.Get())->GetCommandBufferStateTrackerForTesting();
284 
285     // Snapshot the state.
286     CommandBufferStateTracker snapshot = *stateTracker;
287     // Expect no pipeline in the snapshot
288     EXPECT_FALSE(snapshot.HasPipeline());
289 
290     // Create a simple pipeline
291     wgpu::ComputePipelineDescriptor csDesc;
292     csDesc.compute.module = utils::CreateShaderModule(device, R"(
293         [[stage(compute), workgroup_size(1, 1, 1)]]
294         fn main() {
295         })");
296     csDesc.compute.entryPoint = "main";
297     wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc);
298 
299     // Set the pipeline.
300     pass.SetPipeline(pipeline);
301 
302     // Expect the pipeline to be set.
303     EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline.Get());
304 
305     // Restore the state.
306     FromAPI(pass.Get())->RestoreCommandBufferStateForTesting(std::move(snapshot));
307 
308     // Expect no pipeline
309     EXPECT_FALSE(stateTracker->HasPipeline());
310 }
311