1 // Copyright (c) 2019 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include <vector>
16
17 #include "test/opt/pass_fixture.h"
18 #include "test/opt/pass_utils.h"
19
20 namespace spvtools {
21 namespace opt {
22 namespace {
23
24 typedef std::tuple<std::string, bool> GenerateWebGPUInitializersParam;
25
26 using GlobalVariableTest =
27 PassTest<::testing::TestWithParam<GenerateWebGPUInitializersParam>>;
28 using LocalVariableTest =
29 PassTest<::testing::TestWithParam<GenerateWebGPUInitializersParam>>;
30
31 using GenerateWebGPUInitializersTest = PassTest<::testing::Test>;
32
// Appends a single instruction string to an instruction list.
void operator+=(std::vector<const char*>& lhs, const char* rhs) {
  lhs.emplace_back(rhs);
}
36
// Appends every instruction string of `rhs`, in order, to `lhs`.
void operator+=(std::vector<const char*>& lhs,
                const std::vector<const char*>& rhs) {
  using size_type = std::vector<const char*>::size_type;
  // Capture the count and reserve before appending, so no reallocation
  // happens while elements of `rhs` are being read.
  const size_type count = rhs.size();
  lhs.reserve(lhs.size() + count);
  for (size_type i = 0; i < count; ++i) {
    lhs.push_back(rhs[i]);
  }
}
42
GetGlobalVariableTestString(std::string ptr_str,std::string var_str,std::string const_str="")43 std::string GetGlobalVariableTestString(std::string ptr_str,
44 std::string var_str,
45 std::string const_str = "") {
46 std::vector<const char*> result = {
47 // clang-format off
48 "OpCapability Shader",
49 "OpCapability VulkanMemoryModel",
50 "OpExtension \"SPV_KHR_vulkan_memory_model\"",
51 "OpMemoryModel Logical Vulkan",
52 "OpEntryPoint Vertex %1 \"shader\"",
53 "%uint = OpTypeInt 32 0",
54 ptr_str.c_str()};
55 // clang-format on
56
57 if (!const_str.empty()) result += const_str.c_str();
58
59 result += {
60 // clang-format off
61 var_str.c_str(),
62 "%uint_0 = OpConstant %uint 0",
63 "%void = OpTypeVoid",
64 "%7 = OpTypeFunction %void",
65 "%1 = OpFunction %void None %7",
66 "%8 = OpLabel",
67 "OpStore %4 %uint_0",
68 "OpReturn",
69 "OpFunctionEnd"
70 // clang-format on
71 };
72 return JoinAllInsts(result);
73 }
74
// Returns the OpTypePointer declaration "%_ptr_<SC>_uint" for a pointer to
// %uint in storage class <SC>.
std::string GetPointerString(std::string storage_type) {
  return "%_ptr_" + storage_type + "_uint = OpTypePointer " + storage_type +
         " %uint";
}
81
// Returns the module-scope OpVariable declaration (result id %4) of type
// "%_ptr_<SC>_uint". When `initialized` is true the variable references %9 as
// its initializer (the id GetNullConstantString defines).
std::string GetGlobalVariableString(std::string storage_type,
                                    bool initialized) {
  const std::string initializer = initialized ? " %9" : "";
  return "%4 = OpVariable %_ptr_" + storage_type + "_uint " + storage_type +
         initializer;
}
90
GetUninitializedGlobalVariableTestString(std::string storage_type)91 std::string GetUninitializedGlobalVariableTestString(std::string storage_type) {
92 return GetGlobalVariableTestString(
93 GetPointerString(storage_type),
94 GetGlobalVariableString(storage_type, false));
95 }
96
// Declaration of the %uint null constant (result id %9) that initialized
// test variables reference.
std::string GetNullConstantString() {
  return std::string("%9 = OpConstantNull %uint");
}
98
GetInitializedGlobalVariableTestString(std::string storage_type)99 std::string GetInitializedGlobalVariableTestString(std::string storage_type) {
100 return GetGlobalVariableTestString(
101 GetPointerString(storage_type),
102 GetGlobalVariableString(storage_type, true), GetNullConstantString());
103 }
104
// Runs the pass over a module with one uninitialized module-scope variable
// and checks it either gains a null initializer or is left untouched,
// depending on the storage class.
TEST_P(GlobalVariableTest, Check) {
  const GenerateWebGPUInitializersParam& param = GetParam();
  const std::string& storage_class = std::get<0>(param);
  const bool expect_change = std::get<1>(param);

  const std::string before =
      GetUninitializedGlobalVariableTestString(storage_class);
  const std::string after =
      expect_change ? GetInitializedGlobalVariableTestString(storage_class)
                    : before;

  SinglePassRunAndCheck<GenerateWebGPUInitializersPass>(before, after,
                                                        /* skip_nop = */ false);
}
115
116 // clang-format off
117 INSTANTIATE_TEST_SUITE_P(
118 GenerateWebGPUInitializers, GlobalVariableTest,
119 ::testing::ValuesIn(std::vector<GenerateWebGPUInitializersParam>({
120 std::make_tuple("Private", true),
121 std::make_tuple("Output", true),
122 std::make_tuple("Function", true),
123 std::make_tuple("UniformConstant", false),
124 std::make_tuple("Input", false),
125 std::make_tuple("Uniform", false),
126 std::make_tuple("Workgroup", false)
127 })));
128 // clang-format on
129
GetLocalVariableTestString(std::string ptr_str,std::string var_str,std::string const_str="")130 std::string GetLocalVariableTestString(std::string ptr_str, std::string var_str,
131 std::string const_str = "") {
132 std::vector<const char*> result = {
133 // clang-format off
134 "OpCapability Shader",
135 "OpCapability VulkanMemoryModel",
136 "OpExtension \"SPV_KHR_vulkan_memory_model\"",
137 "OpMemoryModel Logical Vulkan",
138 "OpEntryPoint Vertex %1 \"shader\"",
139 "%uint = OpTypeInt 32 0",
140 ptr_str.c_str(),
141 "%uint_0 = OpConstant %uint 0",
142 "%void = OpTypeVoid",
143 "%6 = OpTypeFunction %void"};
144 // clang-format on
145
146 if (!const_str.empty()) result += const_str.c_str();
147
148 result += {
149 // clang-format off
150 "%1 = OpFunction %void None %6",
151 "%7 = OpLabel",
152 var_str.c_str(),
153 "OpStore %8 %uint_0"
154 // clang-format on
155 };
156 return JoinAllInsts(result);
157 }
158
// Returns the OpVariable declaration (result id %8) of type
// "%_ptr_<SC>_uint" used inside the entry block. When `initialized` is true
// the variable references %9 as its initializer (the id
// GetNullConstantString defines).
std::string GetLocalVariableString(std::string storage_type, bool initialized) {
  const std::string initializer = initialized ? " %9" : "";
  return "%8 = OpVariable %_ptr_" + storage_type + "_uint " + storage_type +
         initializer;
}
166
GetUninitializedLocalVariableTestString(std::string storage_type)167 std::string GetUninitializedLocalVariableTestString(std::string storage_type) {
168 return GetLocalVariableTestString(
169 GetPointerString(storage_type),
170 GetLocalVariableString(storage_type, false));
171 }
172
GetInitializedLocalVariableTestString(std::string storage_type)173 std::string GetInitializedLocalVariableTestString(std::string storage_type) {
174 return GetLocalVariableTestString(GetPointerString(storage_type),
175 GetLocalVariableString(storage_type, true),
176 GetNullConstantString());
177 }
178
// Runs the pass over a module whose entry block declares one uninitialized
// variable and checks it either gains a null initializer or is left
// untouched, depending on the storage class.
TEST_P(LocalVariableTest, Check) {
  const GenerateWebGPUInitializersParam& param = GetParam();
  const std::string& storage_class = std::get<0>(param);
  const bool expect_change = std::get<1>(param);

  const std::string before =
      GetUninitializedLocalVariableTestString(storage_class);
  const std::string after =
      expect_change ? GetInitializedLocalVariableTestString(storage_class)
                    : before;

  SinglePassRunAndCheck<GenerateWebGPUInitializersPass>(before, after,
                                                        /* skip_nop = */ false);
}
190
191 // clang-format off
192 INSTANTIATE_TEST_SUITE_P(
193 GenerateWebGPUInitializers, LocalVariableTest,
194 ::testing::ValuesIn(std::vector<GenerateWebGPUInitializersParam>({
195 std::make_tuple("Private", true),
196 std::make_tuple("Output", true),
197 std::make_tuple("Function", true),
198 std::make_tuple("UniformConstant", false),
199 std::make_tuple("Input", false),
200 std::make_tuple("Uniform", false),
201 std::make_tuple("Workgroup", false)
202 })));
203 // clang-format on
204
// A Private variable that already has an initializer (%uint_0) must come out
// of the pass exactly as it went in.
TEST_F(GenerateWebGPUInitializersTest, AlreadyInitializedUnchanged) {
  const std::string module_text = JoinAllInsts(std::vector<const char*>{
      // clang-format off
      "OpCapability Shader",
      "OpCapability VulkanMemoryModel",
      "OpExtension \"SPV_KHR_vulkan_memory_model\"",
      "OpMemoryModel Logical Vulkan",
      "OpEntryPoint Vertex %1 \"shader\"",
      "%uint = OpTypeInt 32 0",
      "%_ptr_Private_uint = OpTypePointer Private %uint",
      "%uint_0 = OpConstant %uint 0",
      "%5 = OpVariable %_ptr_Private_uint Private %uint_0",
      "%void = OpTypeVoid",
      "%7 = OpTypeFunction %void",
      "%1 = OpFunction %void None %7",
      "%8 = OpLabel",
      "OpReturn",
      "OpFunctionEnd"
      // clang-format on
  });

  SinglePassRunAndCheck<GenerateWebGPUInitializersPass>(
      module_text, module_text, /* skip_nop = */ false);
}
230
// Two structurally identical array types exist. %10 (of the second type)
// already uses null constant %8, but %9 (of the first type) is
// uninitialized. The expected output shows the pass declaring a new null
// constant %14 of %9's own type instead of reusing %8.
TEST_F(GenerateWebGPUInitializersTest, AmbigiousArrays) {
  const std::string before = JoinAllInsts(std::vector<const char*>{
      // clang-format off
      "OpCapability Shader",
      "OpCapability VulkanMemoryModel",
      "OpExtension \"SPV_KHR_vulkan_memory_model\"",
      "OpMemoryModel Logical Vulkan",
      "OpEntryPoint Vertex %1 \"shader\"",
      "%uint = OpTypeInt 32 0",
      "%uint_2 = OpConstant %uint 2",
      "%_arr_uint_uint_2 = OpTypeArray %uint %uint_2",
      "%_arr_uint_uint_2_0 = OpTypeArray %uint %uint_2",
      "%_ptr_Private__arr_uint_uint_2 = OpTypePointer Private %_arr_uint_uint_2",
      "%_ptr_Private__arr_uint_uint_2_0 = OpTypePointer Private %_arr_uint_uint_2_0",
      "%8 = OpConstantNull %_arr_uint_uint_2_0",
      "%9 = OpVariable %_ptr_Private__arr_uint_uint_2 Private",
      "%10 = OpVariable %_ptr_Private__arr_uint_uint_2_0 Private %8",
      "%void = OpTypeVoid",
      "%12 = OpTypeFunction %void",
      "%1 = OpFunction %void None %12",
      "%13 = OpLabel",
      "OpReturn",
      "OpFunctionEnd"
      // clang-format on
  });

  const std::string after = JoinAllInsts(std::vector<const char*>{
      // clang-format off
      "OpCapability Shader",
      "OpCapability VulkanMemoryModel",
      "OpExtension \"SPV_KHR_vulkan_memory_model\"",
      "OpMemoryModel Logical Vulkan",
      "OpEntryPoint Vertex %1 \"shader\"",
      "%uint = OpTypeInt 32 0",
      "%uint_2 = OpConstant %uint 2",
      "%_arr_uint_uint_2 = OpTypeArray %uint %uint_2",
      "%_arr_uint_uint_2_0 = OpTypeArray %uint %uint_2",
      "%_ptr_Private__arr_uint_uint_2 = OpTypePointer Private %_arr_uint_uint_2",
      "%_ptr_Private__arr_uint_uint_2_0 = OpTypePointer Private %_arr_uint_uint_2_0",
      "%8 = OpConstantNull %_arr_uint_uint_2_0",
      "%14 = OpConstantNull %_arr_uint_uint_2",
      "%9 = OpVariable %_ptr_Private__arr_uint_uint_2 Private %14",
      "%10 = OpVariable %_ptr_Private__arr_uint_uint_2_0 Private %8",
      "%void = OpTypeVoid",
      "%12 = OpTypeFunction %void",
      "%1 = OpFunction %void None %12",
      "%13 = OpLabel",
      "OpReturn",
      "OpFunctionEnd"
      // clang-format on
  });

  SinglePassRunAndCheck<GenerateWebGPUInitializersPass>(before, after,
                                                        /* skip_nop = */ false);
}
288
// Two structurally identical struct types exist. %8 (of %_struct_3) already
// uses null constant %7, but %9 (of %_struct_4) is uninitialized. The
// expected output shows the pass declaring a new null constant %13 of
// %_struct_4 instead of reusing %7.
TEST_F(GenerateWebGPUInitializersTest, AmbigiousStructs) {
  const std::string before = JoinAllInsts(std::vector<const char*>{
      // clang-format off
      "OpCapability Shader",
      "OpCapability VulkanMemoryModel",
      "OpExtension \"SPV_KHR_vulkan_memory_model\"",
      "OpMemoryModel Logical Vulkan",
      "OpEntryPoint Vertex %1 \"shader\"",
      "%uint = OpTypeInt 32 0",
      "%_struct_3 = OpTypeStruct %uint",
      "%_struct_4 = OpTypeStruct %uint",
      "%_ptr_Private__struct_3 = OpTypePointer Private %_struct_3",
      "%_ptr_Private__struct_4 = OpTypePointer Private %_struct_4",
      "%7 = OpConstantNull %_struct_3",
      "%8 = OpVariable %_ptr_Private__struct_3 Private %7",
      "%9 = OpVariable %_ptr_Private__struct_4 Private",
      "%void = OpTypeVoid",
      "%11 = OpTypeFunction %void",
      "%1 = OpFunction %void None %11",
      "%12 = OpLabel",
      "OpReturn",
      "OpFunctionEnd"
      // clang-format on
  });

  const std::string after = JoinAllInsts(std::vector<const char*>{
      // clang-format off
      "OpCapability Shader",
      "OpCapability VulkanMemoryModel",
      "OpExtension \"SPV_KHR_vulkan_memory_model\"",
      "OpMemoryModel Logical Vulkan",
      "OpEntryPoint Vertex %1 \"shader\"",
      "%uint = OpTypeInt 32 0",
      "%_struct_3 = OpTypeStruct %uint",
      "%_struct_4 = OpTypeStruct %uint",
      "%_ptr_Private__struct_3 = OpTypePointer Private %_struct_3",
      "%_ptr_Private__struct_4 = OpTypePointer Private %_struct_4",
      "%7 = OpConstantNull %_struct_3",
      "%8 = OpVariable %_ptr_Private__struct_3 Private %7",
      "%13 = OpConstantNull %_struct_4",
      "%9 = OpVariable %_ptr_Private__struct_4 Private %13",
      "%void = OpTypeVoid",
      "%11 = OpTypeFunction %void",
      "%1 = OpFunction %void None %11",
      "%12 = OpLabel",
      "OpReturn",
      "OpFunctionEnd"
      // clang-format on
  });

  SinglePassRunAndCheck<GenerateWebGPUInitializersPass>(before, after,
                                                        /* skip_nop = */ false);
}
344
345 } // namespace
346 } // namespace opt
347 } // namespace spvtools
348