• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright (c) 2019 Google LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 
15 #include <vector>
16 
17 #include "test/opt/pass_fixture.h"
18 #include "test/opt/pass_utils.h"
19 
20 namespace spvtools {
21 namespace opt {
22 namespace {
23 
// Parameter of the value-parameterized suites: a SPIR-V storage class name and
// whether the pass is expected to add a null initializer to a variable of that
// storage class.
using GenerateWebGPUInitializersParam = std::tuple<std::string, bool>;
25 
26 using GlobalVariableTest =
27     PassTest<::testing::TestWithParam<GenerateWebGPUInitializersParam>>;
28 using LocalVariableTest =
29     PassTest<::testing::TestWithParam<GenerateWebGPUInitializersParam>>;
30 
31 using GenerateWebGPUInitializersTest = PassTest<::testing::Test>;
32 
// Appends one instruction string to |lhs| so helpers can write
// `insts += "..."` when building instruction lists.
void operator+=(std::vector<const char*>& lhs, const char* rhs) {
  lhs.emplace_back(rhs);
}
36 
// Appends every pointer of |rhs| to |lhs|, preserving order. Capacity for the
// combined size is reserved first so the appends do not reallocate.
void operator+=(std::vector<const char*>& lhs,
                const std::vector<const char*>& rhs) {
  const std::vector<const char*>::size_type count = rhs.size();
  lhs.reserve(lhs.size() + count);
  for (std::vector<const char*>::size_type i = 0; i < count; ++i) {
    lhs.push_back(rhs[i]);
  }
}
42 
GetGlobalVariableTestString(std::string ptr_str,std::string var_str,std::string const_str="")43 std::string GetGlobalVariableTestString(std::string ptr_str,
44                                         std::string var_str,
45                                         std::string const_str = "") {
46   std::vector<const char*> result = {
47       // clang-format off
48                "OpCapability Shader",
49                "OpCapability VulkanMemoryModel",
50                "OpExtension \"SPV_KHR_vulkan_memory_model\"",
51                "OpMemoryModel Logical Vulkan",
52                "OpEntryPoint Vertex %1 \"shader\"",
53        "%uint = OpTypeInt 32 0",
54                 ptr_str.c_str()};
55   // clang-format on
56 
57   if (!const_str.empty()) result += const_str.c_str();
58 
59   result += {
60       // clang-format off
61                 var_str.c_str(),
62      "%uint_0 = OpConstant %uint 0",
63        "%void = OpTypeVoid",
64           "%7 = OpTypeFunction %void",
65           "%1 = OpFunction %void None %7",
66           "%8 = OpLabel",
67                "OpStore %4 %uint_0",
68                "OpReturn",
69                "OpFunctionEnd"
70       // clang-format on
71   };
72   return JoinAllInsts(result);
73 }
74 
// Returns the declaration of the pointer-to-%uint type for |storage_type|,
// e.g. "%_ptr_Private_uint = OpTypePointer Private %uint".
// Appends piece by piece to avoid building intermediate temporaries.
std::string GetPointerString(const std::string& storage_type) {
  std::string result = "%_ptr_";
  result += storage_type;
  result += "_uint = OpTypePointer ";
  result += storage_type;
  result += " %uint";
  return result;
}
81 
// Returns the declaration of the test variable %4 in |storage_type|. When
// |initialized| is true the null constant %9 is appended as its initializer.
std::string GetGlobalVariableString(const std::string& storage_type,
                                    bool initialized) {
  std::string result = "%4 = OpVariable %_ptr_";
  result += storage_type;
  result += "_uint ";
  result += storage_type;
  if (initialized) result += " %9";
  return result;
}
90 
GetUninitializedGlobalVariableTestString(std::string storage_type)91 std::string GetUninitializedGlobalVariableTestString(std::string storage_type) {
92   return GetGlobalVariableTestString(
93       GetPointerString(storage_type),
94       GetGlobalVariableString(storage_type, false));
95 }
96 
// Declaration of the null %uint constant %9 that serves as the initializer in
// the "initialized" variants of the test modules.
std::string GetNullConstantString() {
  static const char kNullConstant[] = "%9 = OpConstantNull %uint";
  return kNullConstant;
}
98 
GetInitializedGlobalVariableTestString(std::string storage_type)99 std::string GetInitializedGlobalVariableTestString(std::string storage_type) {
100   return GetGlobalVariableTestString(
101       GetPointerString(storage_type),
102       GetGlobalVariableString(storage_type, true), GetNullConstantString());
103 }
104 
// Runs the pass over a module whose variable %4 lacks an initializer and checks
// that a null initializer is added exactly when the parameter says so;
// otherwise the module must come back unchanged.
TEST_P(GlobalVariableTest, Check) {
  const auto& param = GetParam();
  const std::string storage_class = std::get<0>(param);
  const bool changed = std::get<1>(param);

  const std::string input =
      GetUninitializedGlobalVariableTestString(storage_class);
  std::string expected = input;
  if (changed) expected = GetInitializedGlobalVariableTestString(storage_class);

  SinglePassRunAndCheck<GenerateWebGPUInitializersPass>(input, expected,
                                                        /* skip_nop = */ false);
}
115 
116 // clang-format off
117 INSTANTIATE_TEST_SUITE_P(
118     GenerateWebGPUInitializers, GlobalVariableTest,
119     ::testing::ValuesIn(std::vector<GenerateWebGPUInitializersParam>({
120        std::make_tuple("Private", true),
121        std::make_tuple("Output", true),
122        std::make_tuple("Function", true),
123        std::make_tuple("UniformConstant", false),
124        std::make_tuple("Input", false),
125        std::make_tuple("Uniform", false),
126        std::make_tuple("Workgroup", false)
127     })));
128 // clang-format on
129 
GetLocalVariableTestString(std::string ptr_str,std::string var_str,std::string const_str="")130 std::string GetLocalVariableTestString(std::string ptr_str, std::string var_str,
131                                        std::string const_str = "") {
132   std::vector<const char*> result = {
133       // clang-format off
134                "OpCapability Shader",
135                "OpCapability VulkanMemoryModel",
136                "OpExtension \"SPV_KHR_vulkan_memory_model\"",
137                "OpMemoryModel Logical Vulkan",
138                "OpEntryPoint Vertex %1 \"shader\"",
139        "%uint = OpTypeInt 32 0",
140                 ptr_str.c_str(),
141      "%uint_0 = OpConstant %uint 0",
142        "%void = OpTypeVoid",
143           "%6 = OpTypeFunction %void"};
144   // clang-format on
145 
146   if (!const_str.empty()) result += const_str.c_str();
147 
148   result += {
149       // clang-format off
150           "%1 = OpFunction %void None %6",
151           "%7 = OpLabel",
152                 var_str.c_str(),
153                "OpStore %8 %uint_0"
154       // clang-format on
155   };
156   return JoinAllInsts(result);
157 }
158 
// Returns the declaration of the test variable %8 in |storage_type|. When
// |initialized| is true the null constant %9 is appended as its initializer.
std::string GetLocalVariableString(const std::string& storage_type,
                                   bool initialized) {
  std::string result = "%8 = OpVariable %_ptr_";
  result += storage_type;
  result += "_uint ";
  result += storage_type;
  if (initialized) result += " %9";
  return result;
}
166 
GetUninitializedLocalVariableTestString(std::string storage_type)167 std::string GetUninitializedLocalVariableTestString(std::string storage_type) {
168   return GetLocalVariableTestString(
169       GetPointerString(storage_type),
170       GetLocalVariableString(storage_type, false));
171 }
172 
GetInitializedLocalVariableTestString(std::string storage_type)173 std::string GetInitializedLocalVariableTestString(std::string storage_type) {
174   return GetLocalVariableTestString(GetPointerString(storage_type),
175                                     GetLocalVariableString(storage_type, true),
176                                     GetNullConstantString());
177 }
178 
// Runs the pass over a module whose in-function variable %8 lacks an
// initializer and checks that a null initializer is added exactly when the
// parameter says so; otherwise the module must come back unchanged.
TEST_P(LocalVariableTest, Check) {
  const auto& param = GetParam();
  const std::string storage_class = std::get<0>(param);
  const bool changed = std::get<1>(param);

  const std::string input =
      GetUninitializedLocalVariableTestString(storage_class);
  std::string expected = input;
  if (changed) expected = GetInitializedLocalVariableTestString(storage_class);

  SinglePassRunAndCheck<GenerateWebGPUInitializersPass>(input, expected,
                                                        /* skip_nop = */ false);
}
190 
191 // clang-format off
192 INSTANTIATE_TEST_SUITE_P(
193     GenerateWebGPUInitializers, LocalVariableTest,
194     ::testing::ValuesIn(std::vector<GenerateWebGPUInitializersParam>({
195        std::make_tuple("Private", true),
196        std::make_tuple("Output", true),
197        std::make_tuple("Function", true),
198        std::make_tuple("UniformConstant", false),
199        std::make_tuple("Input", false),
200        std::make_tuple("Uniform", false),
201        std::make_tuple("Workgroup", false)
202     })));
203 // clang-format on
204 
// A Private variable that already carries an initializer (%uint_0) must be
// left alone: the module is expected to come out of the pass unchanged.
TEST_F(GenerateWebGPUInitializersTest, AlreadyInitializedUnchanged) {
  // clang-format off
  const std::string module_text = JoinAllInsts(std::vector<const char*>{
                       "OpCapability Shader",
                       "OpCapability VulkanMemoryModel",
                       "OpExtension \"SPV_KHR_vulkan_memory_model\"",
                       "OpMemoryModel Logical Vulkan",
                       "OpEntryPoint Vertex %1 \"shader\"",
               "%uint = OpTypeInt 32 0",
  "%_ptr_Private_uint = OpTypePointer Private %uint",
             "%uint_0 = OpConstant %uint 0",
                  "%5 = OpVariable %_ptr_Private_uint Private %uint_0",
               "%void = OpTypeVoid",
                  "%7 = OpTypeFunction %void",
                  "%1 = OpFunction %void None %7",
                  "%8 = OpLabel",
                       "OpReturn",
                       "OpFunctionEnd"});
  // clang-format on

  SinglePassRunAndCheck<GenerateWebGPUInitializersPass>(
      module_text, module_text, /* skip_nop = */ false);
}
230 
TEST_F(GenerateWebGPUInitializersTest,AmbigiousArrays)231 TEST_F(GenerateWebGPUInitializersTest, AmbigiousArrays) {
232   std::vector<const char*> input_spirv = {
233       // clang-format off
234                                    "OpCapability Shader",
235                                    "OpCapability VulkanMemoryModel",
236                                    "OpExtension \"SPV_KHR_vulkan_memory_model\"",
237                                    "OpMemoryModel Logical Vulkan",
238                                    "OpEntryPoint Vertex %1 \"shader\"",
239                            "%uint = OpTypeInt 32 0",
240                          "%uint_2 = OpConstant %uint 2",
241                "%_arr_uint_uint_2 = OpTypeArray %uint %uint_2",
242              "%_arr_uint_uint_2_0 = OpTypeArray %uint %uint_2",
243   "%_ptr_Private__arr_uint_uint_2 = OpTypePointer Private %_arr_uint_uint_2",
244 "%_ptr_Private__arr_uint_uint_2_0 = OpTypePointer Private %_arr_uint_uint_2_0",
245                               "%8 = OpConstantNull %_arr_uint_uint_2_0",
246                               "%9 = OpVariable %_ptr_Private__arr_uint_uint_2 Private",
247                              "%10 = OpVariable %_ptr_Private__arr_uint_uint_2_0 Private %8",
248                            "%void = OpTypeVoid",
249                              "%12 = OpTypeFunction %void",
250                               "%1 = OpFunction %void None %12",
251                              "%13 = OpLabel",
252                                    "OpReturn",
253                                    "OpFunctionEnd"
254       // clang-format on
255   };
256   std::string input_str = JoinAllInsts(input_spirv);
257 
258   std::vector<const char*> expected_spirv = {
259       // clang-format off
260                                    "OpCapability Shader",
261                                    "OpCapability VulkanMemoryModel",
262                                    "OpExtension \"SPV_KHR_vulkan_memory_model\"",
263                                    "OpMemoryModel Logical Vulkan",
264                                    "OpEntryPoint Vertex %1 \"shader\"",
265                            "%uint = OpTypeInt 32 0",
266                          "%uint_2 = OpConstant %uint 2",
267                "%_arr_uint_uint_2 = OpTypeArray %uint %uint_2",
268              "%_arr_uint_uint_2_0 = OpTypeArray %uint %uint_2",
269   "%_ptr_Private__arr_uint_uint_2 = OpTypePointer Private %_arr_uint_uint_2",
270 "%_ptr_Private__arr_uint_uint_2_0 = OpTypePointer Private %_arr_uint_uint_2_0",
271                               "%8 = OpConstantNull %_arr_uint_uint_2_0",
272                              "%14 = OpConstantNull %_arr_uint_uint_2",
273                               "%9 = OpVariable %_ptr_Private__arr_uint_uint_2 Private %14",
274                              "%10 = OpVariable %_ptr_Private__arr_uint_uint_2_0 Private %8",
275                            "%void = OpTypeVoid",
276                              "%12 = OpTypeFunction %void",
277                               "%1 = OpFunction %void None %12",
278                              "%13 = OpLabel",
279                                    "OpReturn",
280                                    "OpFunctionEnd"
281       // clang-format on
282   };
283   std::string expected_str = JoinAllInsts(expected_spirv);
284 
285   SinglePassRunAndCheck<GenerateWebGPUInitializersPass>(input_str, expected_str,
286                                                         /* skip_nop = */ false);
287 }
288 
TEST_F(GenerateWebGPUInitializersTest,AmbigiousStructs)289 TEST_F(GenerateWebGPUInitializersTest, AmbigiousStructs) {
290   std::vector<const char*> input_spirv = {
291       // clang-format off
292                           "OpCapability Shader",
293                           "OpCapability VulkanMemoryModel",
294                           "OpExtension \"SPV_KHR_vulkan_memory_model\"",
295                           "OpMemoryModel Logical Vulkan",
296                           "OpEntryPoint Vertex %1 \"shader\"",
297                   "%uint = OpTypeInt 32 0",
298              "%_struct_3 = OpTypeStruct %uint",
299              "%_struct_4 = OpTypeStruct %uint",
300 "%_ptr_Private__struct_3 = OpTypePointer Private %_struct_3",
301 "%_ptr_Private__struct_4 = OpTypePointer Private %_struct_4",
302                      "%7 = OpConstantNull %_struct_3",
303                      "%8 = OpVariable %_ptr_Private__struct_3 Private %7",
304                      "%9 = OpVariable %_ptr_Private__struct_4 Private",
305                   "%void = OpTypeVoid",
306                     "%11 = OpTypeFunction %void",
307                      "%1 = OpFunction %void None %11",
308                     "%12 = OpLabel",
309                           "OpReturn",
310                           "OpFunctionEnd"
311       // clang-format on
312   };
313   std::string input_str = JoinAllInsts(input_spirv);
314 
315   std::vector<const char*> expected_spirv = {
316       // clang-format off
317                           "OpCapability Shader",
318                           "OpCapability VulkanMemoryModel",
319                           "OpExtension \"SPV_KHR_vulkan_memory_model\"",
320                           "OpMemoryModel Logical Vulkan",
321                           "OpEntryPoint Vertex %1 \"shader\"",
322                   "%uint = OpTypeInt 32 0",
323              "%_struct_3 = OpTypeStruct %uint",
324              "%_struct_4 = OpTypeStruct %uint",
325 "%_ptr_Private__struct_3 = OpTypePointer Private %_struct_3",
326 "%_ptr_Private__struct_4 = OpTypePointer Private %_struct_4",
327                      "%7 = OpConstantNull %_struct_3",
328                      "%8 = OpVariable %_ptr_Private__struct_3 Private %7",
329                     "%13 = OpConstantNull %_struct_4",
330                      "%9 = OpVariable %_ptr_Private__struct_4 Private %13",
331                   "%void = OpTypeVoid",
332                     "%11 = OpTypeFunction %void",
333                      "%1 = OpFunction %void None %11",
334                     "%12 = OpLabel",
335                           "OpReturn",
336                           "OpFunctionEnd"
337       // clang-format on
338   };
339   std::string expected_str = JoinAllInsts(expected_spirv);
340 
341   SinglePassRunAndCheck<GenerateWebGPUInitializersPass>(input_str, expected_str,
342                                                         /* skip_nop = */ false);
343 }
344 
345 }  // namespace
346 }  // namespace opt
347 }  // namespace spvtools
348