• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2018 The Dawn Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "tests/DawnTest.h"
16 
17 #include "utils/WGPUHelpers.h"
18 
19 #include <array>
20 
21 class ComputeCopyStorageBufferTests : public DawnTest {
22   public:
23     static constexpr int kInstances = 4;
24     static constexpr int kUintsPerInstance = 4;
25     static constexpr int kNumUints = kInstances * kUintsPerInstance;
26 
27     void BasicTest(const char* shader);
28 };
29 
BasicTest(const char * shader)30 void ComputeCopyStorageBufferTests::BasicTest(const char* shader) {
31     // Set up shader and pipeline
32     auto module = utils::CreateShaderModule(device, shader);
33 
34     wgpu::ComputePipelineDescriptor csDesc;
35     csDesc.compute.module = module;
36     csDesc.compute.entryPoint = "main";
37 
38     wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc);
39 
40     // Set up src storage buffer
41     wgpu::BufferDescriptor srcDesc;
42     srcDesc.size = kNumUints * sizeof(uint32_t);
43     srcDesc.usage =
44         wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst;
45     wgpu::Buffer src = device.CreateBuffer(&srcDesc);
46 
47     std::array<uint32_t, kNumUints> expected;
48     for (uint32_t i = 0; i < kNumUints; ++i) {
49         expected[i] = (i + 1u) * 0x11111111u;
50     }
51     queue.WriteBuffer(src, 0, expected.data(), sizeof(expected));
52     EXPECT_BUFFER_U32_RANGE_EQ(expected.data(), src, 0, kNumUints);
53 
54     // Set up dst storage buffer
55     wgpu::BufferDescriptor dstDesc;
56     dstDesc.size = kNumUints * sizeof(uint32_t);
57     dstDesc.usage =
58         wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst;
59     wgpu::Buffer dst = device.CreateBuffer(&dstDesc);
60 
61     std::array<uint32_t, kNumUints> zero{};
62     queue.WriteBuffer(dst, 0, zero.data(), sizeof(zero));
63 
64     // Set up bind group and issue dispatch
65     wgpu::BindGroup bindGroup = utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0),
66                                                      {
67                                                          {0, src, 0, kNumUints * sizeof(uint32_t)},
68                                                          {1, dst, 0, kNumUints * sizeof(uint32_t)},
69                                                      });
70 
71     wgpu::CommandBuffer commands;
72     {
73         wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
74         wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
75         pass.SetPipeline(pipeline);
76         pass.SetBindGroup(0, bindGroup);
77         pass.Dispatch(kInstances);
78         pass.EndPass();
79 
80         commands = encoder.Finish();
81     }
82 
83     queue.Submit(1, &commands);
84 
85     EXPECT_BUFFER_U32_RANGE_EQ(expected.data(), dst, 0, kNumUints);
86 }
87 
// Test that a trivial compute-shader memcpy works on a fixed-size array of vec4<u32>.
TEST_P(ComputeCopyStorageBufferTests, SizedArrayOfBasic) {
    // Each invocation copies one vec4<u32> element of a fixed-size (length 4) array,
    // ignoring invocation indices >= 4.
    constexpr char kShader[] = R"(
        [[block]] struct Buf {
            s : array<vec4<u32>, 4>;
        };

        [[group(0), binding(0)]] var<storage, read_write> src : Buf;
        [[group(0), binding(1)]] var<storage, read_write> dst : Buf;

        [[stage(compute), workgroup_size(1)]]
        fn main([[builtin(global_invocation_id)]] GlobalInvocationID : vec3<u32>) {
            let index : u32 = GlobalInvocationID.x;
            if (index >= 4u) { return; }
            dst.s[index] = src.s[index];
        })";
    BasicTest(kShader);
}
105 
// Test that a compute-shader memcpy works on a fixed-size array of structs (two vec2<u32> each).
TEST_P(ComputeCopyStorageBufferTests, SizedArrayOfStruct) {
    // Each invocation copies one element of a fixed-size (length 4) array of structs,
    // where each struct holds two vec2<u32> members.
    constexpr char kShader[] = R"(
        struct S {
            a : vec2<u32>;
            b : vec2<u32>;
        };

        [[block]] struct Buf {
            s : array<S, 4>;
        };

        [[group(0), binding(0)]] var<storage, read_write> src : Buf;
        [[group(0), binding(1)]] var<storage, read_write> dst : Buf;

        [[stage(compute), workgroup_size(1)]]
        fn main([[builtin(global_invocation_id)]] GlobalInvocationID : vec3<u32>) {
            let index : u32 = GlobalInvocationID.x;
            if (index >= 4u) { return; }
            dst.s[index] = src.s[index];
        })";
    BasicTest(kShader);
}
128 
// Test that a trivial compute-shader memcpy works on a runtime-sized array of vec4<u32>.
TEST_P(ComputeCopyStorageBufferTests, UnsizedArrayOfBasic) {
    // Each invocation copies one vec4<u32> element of a runtime-sized array; the shader
    // still ignores invocation indices >= 4.
    constexpr char kShader[] = R"(
        [[block]] struct Buf {
            s : array<vec4<u32>>;
        };

        [[group(0), binding(0)]] var<storage, read_write> src : Buf;
        [[group(0), binding(1)]] var<storage, read_write> dst : Buf;

        [[stage(compute), workgroup_size(1)]]
        fn main([[builtin(global_invocation_id)]] GlobalInvocationID : vec3<u32>) {
            let index : u32 = GlobalInvocationID.x;
            if (index >= 4u) { return; }
            dst.s[index] = src.s[index];
        })";
    BasicTest(kShader);
}
146 
// Instantiate the ComputeCopyStorageBufferTests cases for each backend listed below.
// NOTE(review): the macro is presumably provided via tests/DawnTest.h; whether
// backends unavailable on the host are skipped at runtime should be confirmed there.
DAWN_INSTANTIATE_TEST(ComputeCopyStorageBufferTests,
                      D3D12Backend(),
                      MetalBackend(),
                      OpenGLBackend(),
                      OpenGLESBackend(),
                      VulkanBackend());
153