1 // Copyright 2021 The Dawn Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "tests/DawnNativeTest.h"
16
17 #include "dawn_native/CommandBuffer.h"
18 #include "dawn_native/Commands.h"
19 #include "dawn_native/ComputePassEncoder.h"
20 #include "utils/WGPUHelpers.h"
21
22 class CommandBufferEncodingTests : public DawnNativeTest {
23 protected:
ExpectCommands(dawn_native::CommandIterator * commands,std::vector<std::pair<dawn_native::Command,std::function<void (dawn_native::CommandIterator *)>>> expectedCommands)24 void ExpectCommands(dawn_native::CommandIterator* commands,
25 std::vector<std::pair<dawn_native::Command,
26 std::function<void(dawn_native::CommandIterator*)>>>
27 expectedCommands) {
28 dawn_native::Command commandId;
29 for (uint32_t commandIndex = 0; commands->NextCommandId(&commandId); ++commandIndex) {
30 ASSERT_LT(commandIndex, expectedCommands.size()) << "Unexpected command";
31 ASSERT_EQ(commandId, expectedCommands[commandIndex].first)
32 << "at command " << commandIndex;
33 expectedCommands[commandIndex].second(commands);
34 }
35 }
36 };
37
38 // Indirect dispatch validation changes the bind groups in the middle
39 // of a pass. Test that bindings are restored after the validation runs.
TEST_F(CommandBufferEncodingTests,ComputePassEncoderIndirectDispatchStateRestoration)40 TEST_F(CommandBufferEncodingTests, ComputePassEncoderIndirectDispatchStateRestoration) {
41 using namespace dawn_native;
42
43 wgpu::BindGroupLayout staticLayout =
44 utils::MakeBindGroupLayout(device, {{
45 0,
46 wgpu::ShaderStage::Compute,
47 wgpu::BufferBindingType::Uniform,
48 }});
49
50 wgpu::BindGroupLayout dynamicLayout =
51 utils::MakeBindGroupLayout(device, {{
52 0,
53 wgpu::ShaderStage::Compute,
54 wgpu::BufferBindingType::Uniform,
55 true,
56 }});
57
58 // Create a simple pipeline
59 wgpu::ComputePipelineDescriptor csDesc;
60 csDesc.compute.module = utils::CreateShaderModule(device, R"(
61 [[stage(compute), workgroup_size(1, 1, 1)]]
62 fn main() {
63 })");
64 csDesc.compute.entryPoint = "main";
65
66 wgpu::PipelineLayout pl0 = utils::MakePipelineLayout(device, {staticLayout, dynamicLayout});
67 csDesc.layout = pl0;
68 wgpu::ComputePipeline pipeline0 = device.CreateComputePipeline(&csDesc);
69
70 wgpu::PipelineLayout pl1 = utils::MakePipelineLayout(device, {dynamicLayout, staticLayout});
71 csDesc.layout = pl1;
72 wgpu::ComputePipeline pipeline1 = device.CreateComputePipeline(&csDesc);
73
74 // Create buffers to use for both the indirect buffer and the bind groups.
75 wgpu::Buffer indirectBuffer =
76 utils::CreateBufferFromData<uint32_t>(device, wgpu::BufferUsage::Indirect, {1, 2, 3, 4});
77
78 wgpu::BufferDescriptor uniformBufferDesc = {};
79 uniformBufferDesc.size = 512;
80 uniformBufferDesc.usage = wgpu::BufferUsage::Uniform;
81 wgpu::Buffer uniformBuffer = device.CreateBuffer(&uniformBufferDesc);
82
83 wgpu::BindGroup staticBG = utils::MakeBindGroup(device, staticLayout, {{0, uniformBuffer}});
84
85 wgpu::BindGroup dynamicBG =
86 utils::MakeBindGroup(device, dynamicLayout, {{0, uniformBuffer, 0, 256}});
87
88 uint32_t dynamicOffset = 256;
89 std::vector<uint32_t> emptyDynamicOffsets = {};
90 std::vector<uint32_t> singleDynamicOffset = {dynamicOffset};
91
92 // Begin encoding commands.
93 wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
94 wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
95
96 CommandBufferStateTracker* stateTracker =
97 FromAPI(pass.Get())->GetCommandBufferStateTrackerForTesting();
98
99 // Perform a dispatch indirect which will be preceded by a validation dispatch.
100 pass.SetPipeline(pipeline0);
101 pass.SetBindGroup(0, staticBG);
102 pass.SetBindGroup(1, dynamicBG, 1, &dynamicOffset);
103 EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline0.Get());
104
105 pass.DispatchIndirect(indirectBuffer, 0);
106
107 // Expect restored state.
108 EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline0.Get());
109 EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl0.Get());
110 EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(0))), staticBG.Get());
111 EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(0)), emptyDynamicOffsets);
112 EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(1))), dynamicBG.Get());
113 EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(1)), singleDynamicOffset);
114
115 // Dispatch again to check that the restored state can be used.
116 // Also pass an indirect offset which should get replaced with the offset
117 // into the scratch indirect buffer (0).
118 pass.DispatchIndirect(indirectBuffer, 4);
119
120 // Expect restored state.
121 EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline0.Get());
122 EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl0.Get());
123 EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(0))), staticBG.Get());
124 EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(0)), emptyDynamicOffsets);
125 EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(1))), dynamicBG.Get());
126 EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(1)), singleDynamicOffset);
127
128 // Change the pipeline
129 pass.SetPipeline(pipeline1);
130 pass.SetBindGroup(0, dynamicBG, 1, &dynamicOffset);
131 pass.SetBindGroup(1, staticBG);
132 EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline1.Get());
133 EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl1.Get());
134
135 pass.DispatchIndirect(indirectBuffer, 0);
136
137 // Expect restored state.
138 EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline1.Get());
139 EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl1.Get());
140 EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(0))), dynamicBG.Get());
141 EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(0)), singleDynamicOffset);
142 EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(1))), staticBG.Get());
143 EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(1)), emptyDynamicOffsets);
144
145 pass.EndPass();
146
147 wgpu::CommandBuffer commandBuffer = encoder.Finish();
148
149 auto ExpectSetPipeline = [](wgpu::ComputePipeline pipeline) {
150 return [pipeline](CommandIterator* commands) {
151 auto* cmd = commands->NextCommand<SetComputePipelineCmd>();
152 EXPECT_EQ(ToAPI(cmd->pipeline.Get()), pipeline.Get());
153 };
154 };
155
156 auto ExpectSetBindGroup = [](uint32_t index, wgpu::BindGroup bg,
157 std::vector<uint32_t> offsets = {}) {
158 return [index, bg, offsets](CommandIterator* commands) {
159 auto* cmd = commands->NextCommand<SetBindGroupCmd>();
160 uint32_t* dynamicOffsets = nullptr;
161 if (cmd->dynamicOffsetCount > 0) {
162 dynamicOffsets = commands->NextData<uint32_t>(cmd->dynamicOffsetCount);
163 }
164
165 ASSERT_EQ(cmd->index, BindGroupIndex(index));
166 ASSERT_EQ(ToAPI(cmd->group.Get()), bg.Get());
167 ASSERT_EQ(cmd->dynamicOffsetCount, offsets.size());
168 for (uint32_t i = 0; i < cmd->dynamicOffsetCount; ++i) {
169 ASSERT_EQ(dynamicOffsets[i], offsets[i]);
170 }
171 };
172 };
173
174 // Initialize as null. Once we know the pointer, we'll check
175 // that it's the same buffer every time.
176 WGPUBuffer indirectScratchBuffer = nullptr;
177 auto ExpectDispatchIndirect = [&](CommandIterator* commands) {
178 auto* cmd = commands->NextCommand<DispatchIndirectCmd>();
179 if (indirectScratchBuffer == nullptr) {
180 indirectScratchBuffer = ToAPI(cmd->indirectBuffer.Get());
181 }
182 ASSERT_EQ(ToAPI(cmd->indirectBuffer.Get()), indirectScratchBuffer);
183 ASSERT_EQ(cmd->indirectOffset, uint64_t(0));
184 };
185
186 // Initialize as null. Once we know the pointer, we'll check
187 // that it's the same pipeline every time.
188 WGPUComputePipeline validationPipeline = nullptr;
189 auto ExpectSetValidationPipeline = [&](CommandIterator* commands) {
190 auto* cmd = commands->NextCommand<SetComputePipelineCmd>();
191 WGPUComputePipeline pipeline = ToAPI(cmd->pipeline.Get());
192 if (validationPipeline != nullptr) {
193 EXPECT_EQ(pipeline, validationPipeline);
194 } else {
195 EXPECT_NE(pipeline, nullptr);
196 validationPipeline = pipeline;
197 }
198 };
199
200 auto ExpectSetValidationBindGroup = [&](CommandIterator* commands) {
201 auto* cmd = commands->NextCommand<SetBindGroupCmd>();
202 ASSERT_EQ(cmd->index, BindGroupIndex(0));
203 ASSERT_NE(cmd->group.Get(), nullptr);
204 ASSERT_EQ(cmd->dynamicOffsetCount, 0u);
205 };
206
207 auto ExpectSetValidationDispatch = [&](CommandIterator* commands) {
208 auto* cmd = commands->NextCommand<DispatchCmd>();
209 ASSERT_EQ(cmd->x, 1u);
210 ASSERT_EQ(cmd->y, 1u);
211 ASSERT_EQ(cmd->z, 1u);
212 };
213
214 ExpectCommands(
215 FromAPI(commandBuffer.Get())->GetCommandIteratorForTesting(),
216 {
217 {Command::BeginComputePass,
218 [&](CommandIterator* commands) { SkipCommand(commands, Command::BeginComputePass); }},
219 // Expect the state to be set.
220 {Command::SetComputePipeline, ExpectSetPipeline(pipeline0)},
221 {Command::SetBindGroup, ExpectSetBindGroup(0, staticBG)},
222 {Command::SetBindGroup, ExpectSetBindGroup(1, dynamicBG, {dynamicOffset})},
223
224 // Expect the validation.
225 {Command::SetComputePipeline, ExpectSetValidationPipeline},
226 {Command::SetBindGroup, ExpectSetValidationBindGroup},
227 {Command::Dispatch, ExpectSetValidationDispatch},
228
229 // Expect the state to be restored.
230 {Command::SetComputePipeline, ExpectSetPipeline(pipeline0)},
231 {Command::SetBindGroup, ExpectSetBindGroup(0, staticBG)},
232 {Command::SetBindGroup, ExpectSetBindGroup(1, dynamicBG, {dynamicOffset})},
233
234 // Expect the dispatchIndirect.
235 {Command::DispatchIndirect, ExpectDispatchIndirect},
236
237 // Expect the validation.
238 {Command::SetComputePipeline, ExpectSetValidationPipeline},
239 {Command::SetBindGroup, ExpectSetValidationBindGroup},
240 {Command::Dispatch, ExpectSetValidationDispatch},
241
242 // Expect the state to be restored.
243 {Command::SetComputePipeline, ExpectSetPipeline(pipeline0)},
244 {Command::SetBindGroup, ExpectSetBindGroup(0, staticBG)},
245 {Command::SetBindGroup, ExpectSetBindGroup(1, dynamicBG, {dynamicOffset})},
246
247 // Expect the dispatchIndirect.
248 {Command::DispatchIndirect, ExpectDispatchIndirect},
249
250 // Expect the state to be set (new pipeline).
251 {Command::SetComputePipeline, ExpectSetPipeline(pipeline1)},
252 {Command::SetBindGroup, ExpectSetBindGroup(0, dynamicBG, {dynamicOffset})},
253 {Command::SetBindGroup, ExpectSetBindGroup(1, staticBG)},
254
255 // Expect the validation.
256 {Command::SetComputePipeline, ExpectSetValidationPipeline},
257 {Command::SetBindGroup, ExpectSetValidationBindGroup},
258 {Command::Dispatch, ExpectSetValidationDispatch},
259
260 // Expect the state to be restored.
261 {Command::SetComputePipeline, ExpectSetPipeline(pipeline1)},
262 {Command::SetBindGroup, ExpectSetBindGroup(0, dynamicBG, {dynamicOffset})},
263 {Command::SetBindGroup, ExpectSetBindGroup(1, staticBG)},
264
265 // Expect the dispatchIndirect.
266 {Command::DispatchIndirect, ExpectDispatchIndirect},
267
268 {Command::EndComputePass,
269 [&](CommandIterator* commands) { commands->NextCommand<EndComputePassCmd>(); }},
270 });
271 }
272
273 // Test that after restoring state, it is fully applied to the state tracker
274 // and does not leak state changes that occured between a snapshot and the
275 // state restoration.
TEST_F(CommandBufferEncodingTests,StateNotLeakedAfterRestore)276 TEST_F(CommandBufferEncodingTests, StateNotLeakedAfterRestore) {
277 using namespace dawn_native;
278
279 wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
280 wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
281
282 CommandBufferStateTracker* stateTracker =
283 FromAPI(pass.Get())->GetCommandBufferStateTrackerForTesting();
284
285 // Snapshot the state.
286 CommandBufferStateTracker snapshot = *stateTracker;
287 // Expect no pipeline in the snapshot
288 EXPECT_FALSE(snapshot.HasPipeline());
289
290 // Create a simple pipeline
291 wgpu::ComputePipelineDescriptor csDesc;
292 csDesc.compute.module = utils::CreateShaderModule(device, R"(
293 [[stage(compute), workgroup_size(1, 1, 1)]]
294 fn main() {
295 })");
296 csDesc.compute.entryPoint = "main";
297 wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc);
298
299 // Set the pipeline.
300 pass.SetPipeline(pipeline);
301
302 // Expect the pipeline to be set.
303 EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline.Get());
304
305 // Restore the state.
306 FromAPI(pass.Get())->RestoreCommandBufferStateForTesting(std::move(snapshot));
307
308 // Expect no pipeline
309 EXPECT_FALSE(stateTracker->HasPipeline());
310 }
311