• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2017 The Dawn Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "common/Constants.h"
16 #include "tests/unittests/validation/ValidationTest.h"
17 #include "utils/WGPUHelpers.h"
18 
19 // TODO(cwallez@chromium.org): Add a regression test for Dispatch validation trying to access the
20 // input state.
21 
22 class ComputeValidationTest : public ValidationTest {
23   protected:
SetUp()24     void SetUp() override {
25         ValidationTest::SetUp();
26 
27         wgpu::ShaderModule computeModule = utils::CreateShaderModule(device, R"(
28             [[stage(compute), workgroup_size(1)]] fn main() {
29             })");
30 
31         // Set up compute pipeline
32         wgpu::PipelineLayout pl = utils::MakeBasicPipelineLayout(device, nullptr);
33 
34         wgpu::ComputePipelineDescriptor csDesc;
35         csDesc.layout = pl;
36         csDesc.compute.module = computeModule;
37         csDesc.compute.entryPoint = "main";
38         pipeline = device.CreateComputePipeline(&csDesc);
39     }
40 
TestDispatch(uint32_t x,uint32_t y,uint32_t z)41     void TestDispatch(uint32_t x, uint32_t y, uint32_t z) {
42         wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
43         wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
44         pass.SetPipeline(pipeline);
45         pass.Dispatch(x, y, z);
46         pass.EndPass();
47         encoder.Finish();
48     }
49 
50     wgpu::ComputePipeline pipeline;
51 };
52 
53 // Check that 1x1x1 dispatch is OK.
TEST_F(ComputeValidationTest,PerDimensionDispatchSizeLimits_SmallestValid)54 TEST_F(ComputeValidationTest, PerDimensionDispatchSizeLimits_SmallestValid) {
55     TestDispatch(1, 1, 1);
56 }
57 
58 // Check that the largest allowed dispatch is OK.
TEST_F(ComputeValidationTest,PerDimensionDispatchSizeLimits_LargestValid)59 TEST_F(ComputeValidationTest, PerDimensionDispatchSizeLimits_LargestValid) {
60     const uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension;
61     TestDispatch(max, max, max);
62 }
63 
64 // Check that exceeding the maximum on the X dimension results in validation failure.
TEST_F(ComputeValidationTest,PerDimensionDispatchSizeLimits_InvalidX)65 TEST_F(ComputeValidationTest, PerDimensionDispatchSizeLimits_InvalidX) {
66     const uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension;
67     ASSERT_DEVICE_ERROR(TestDispatch(max + 1, 1, 1));
68 }
69 
70 // Check that exceeding the maximum on the Y dimension results in validation failure.
TEST_F(ComputeValidationTest,PerDimensionDispatchSizeLimits_InvalidY)71 TEST_F(ComputeValidationTest, PerDimensionDispatchSizeLimits_InvalidY) {
72     const uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension;
73     ASSERT_DEVICE_ERROR(TestDispatch(1, max + 1, 1));
74 }
75 
76 // Check that exceeding the maximum on the Z dimension results in validation failure.
TEST_F(ComputeValidationTest,PerDimensionDispatchSizeLimits_InvalidZ)77 TEST_F(ComputeValidationTest, PerDimensionDispatchSizeLimits_InvalidZ) {
78     const uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension;
79     ASSERT_DEVICE_ERROR(TestDispatch(1, 1, max + 1));
80 }
81 
82 // Check that exceeding the maximum on all dimensions results in validation failure.
TEST_F(ComputeValidationTest,PerDimensionDispatchSizeLimits_InvalidAll)83 TEST_F(ComputeValidationTest, PerDimensionDispatchSizeLimits_InvalidAll) {
84     const uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension;
85     ASSERT_DEVICE_ERROR(TestDispatch(max + 1, max + 1, max + 1));
86 }
87