1 // Copyright (c) 2018 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
#include <memory>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>
19
20 #include "function_utils.h"
21 #include "gmock/gmock.h"
22 #include "gtest/gtest.h"
23 #include "source/opt/build_module.h"
24 #include "source/opt/ir_context.h"
25
26 namespace spvtools {
27 namespace opt {
28 namespace {
29
30 using ::testing::Eq;
31
// Verifies Function::HasEarlyReturn(): %6 returns only from its final block
// %10, while %11 returns from %13 inside a selection construct before control
// reaches the merge block %15.
TEST(FunctionTest, HasEarlyReturn) {
  std::string shader = R"(
          OpCapability Shader
     %1 = OpExtInstImport "GLSL.std.450"
          OpMemoryModel Logical GLSL450
          OpEntryPoint Vertex %6 "main"

; Types
     %2 = OpTypeBool
     %3 = OpTypeVoid
     %4 = OpTypeFunction %3

; Constants
     %5 = OpConstantTrue %2

; main function without early return
     %6 = OpFunction %3 None %4
     %7 = OpLabel
          OpBranch %8
     %8 = OpLabel
          OpBranch %9
     %9 = OpLabel
          OpBranch %10
    %10 = OpLabel
          OpReturn
          OpFunctionEnd

; function with early return
    %11 = OpFunction %3 None %4
    %12 = OpLabel
          OpSelectionMerge %15 None
          OpBranchConditional %5 %13 %14
    %13 = OpLabel
          OpReturn
    %14 = OpLabel
          OpBranch %15
    %15 = OpLabel
          OpReturn
          OpFunctionEnd
)";

  const auto context =
      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, shader,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  // Fail the test cleanly instead of crashing on a null dereference if the
  // module does not assemble or a function id cannot be found.
  ASSERT_TRUE(context);

  // Tests |function| without early return.
  auto* function = spvtest::GetFunction(context->module(), 6);
  ASSERT_NE(nullptr, function);
  ASSERT_FALSE(function->HasEarlyReturn());

  // Tests |function| with early return.
  function = spvtest::GetFunction(context->module(), 11);
  ASSERT_NE(nullptr, function);
  ASSERT_TRUE(function->HasEarlyReturn());
}
85
// %9 calls the leaf function %12, and %12 calls nothing, so there is no call
// cycle: Function::IsRecursive() must be false for both of them.
TEST(FunctionTest, IsNotRecursive) {
  const std::string text = R"(
               OpCapability Shader
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %1 "main"
               OpExecutionMode %1 OriginUpperLeft
               OpDecorate %2 DescriptorSet 439418829
       %void = OpTypeVoid
          %4 = OpTypeFunction %void
      %float = OpTypeFloat 32
  %_struct_6 = OpTypeStruct %float %float
          %7 = OpTypeFunction %_struct_6
          %1 = OpFunction %void Pure|Const %4
          %8 = OpLabel
          %2 = OpFunctionCall %_struct_6 %9
               OpKill
               OpFunctionEnd
          %9 = OpFunction %_struct_6 None %7
         %10 = OpLabel
         %11 = OpFunctionCall %_struct_6 %12
               OpUnreachable
               OpFunctionEnd
         %12 = OpFunction %_struct_6 None %7
         %13 = OpLabel
               OpUnreachable
               OpFunctionEnd
)";

  const auto ctx = spvtools::BuildModule(
      SPV_ENV_UNIVERSAL_1_1, nullptr, text,
      SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);

  // Check the caller first, then the leaf it calls.
  for (const uint32_t acyclic_id : {9u, 12u}) {
    auto* const candidate = spvtest::GetFunction(ctx->module(), acyclic_id);
    EXPECT_FALSE(candidate->IsRecursive());
  }
}
123
// %9 contains an OpFunctionCall whose callee is %9 itself, the simplest
// possible call cycle, so Function::IsRecursive() must be true for it.
TEST(FunctionTest, IsDirectlyRecursive) {
  const std::string text = R"(
               OpCapability Shader
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %1 "main"
               OpExecutionMode %1 OriginUpperLeft
               OpDecorate %2 DescriptorSet 439418829
       %void = OpTypeVoid
          %4 = OpTypeFunction %void
      %float = OpTypeFloat 32
  %_struct_6 = OpTypeStruct %float %float
          %7 = OpTypeFunction %_struct_6
          %1 = OpFunction %void Pure|Const %4
          %8 = OpLabel
          %2 = OpFunctionCall %_struct_6 %9
               OpKill
               OpFunctionEnd
          %9 = OpFunction %_struct_6 None %7
         %10 = OpLabel
         %11 = OpFunctionCall %_struct_6 %9
               OpUnreachable
               OpFunctionEnd
)";

  const auto ctx = spvtools::BuildModule(
      SPV_ENV_UNIVERSAL_1_1, nullptr, text,
      SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  auto* const self_caller = spvtest::GetFunction(ctx->module(), 9);
  EXPECT_TRUE(self_caller->IsRecursive());
}
154
// %9 calls %12 and %12 calls %9 back, forming a two-function call cycle; each
// member of the cycle must be reported as recursive.
TEST(FunctionTest, IsIndirectlyRecursive) {
  const std::string text = R"(
               OpCapability Shader
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %1 "main"
               OpExecutionMode %1 OriginUpperLeft
               OpDecorate %2 DescriptorSet 439418829
       %void = OpTypeVoid
          %4 = OpTypeFunction %void
      %float = OpTypeFloat 32
  %_struct_6 = OpTypeStruct %float %float
          %7 = OpTypeFunction %_struct_6
          %1 = OpFunction %void Pure|Const %4
          %8 = OpLabel
          %2 = OpFunctionCall %_struct_6 %9
               OpKill
               OpFunctionEnd
          %9 = OpFunction %_struct_6 None %7
         %10 = OpLabel
         %11 = OpFunctionCall %_struct_6 %12
               OpUnreachable
               OpFunctionEnd
         %12 = OpFunction %_struct_6 None %7
         %13 = OpLabel
         %14 = OpFunctionCall %_struct_6 %9
               OpUnreachable
               OpFunctionEnd
)";

  const auto ctx = spvtools::BuildModule(
      SPV_ENV_UNIVERSAL_1_1, nullptr, text,
      SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);

  // Both ends of the %9 <-> %12 cycle, checked in that order.
  for (const uint32_t cycle_member : {9u, 12u}) {
    auto* const fn = spvtest::GetFunction(ctx->module(), cycle_member);
    EXPECT_TRUE(fn->IsRecursive());
  }
}
193
// The entry point %1 calls %9, and %9 calls itself. Only %9 lies on a call
// cycle, so %1 must not be reported as recursive merely because it calls a
// recursive function. (The test name's spelling is left unchanged so existing
// --gtest_filter patterns keep matching.)
TEST(FunctionTest, IsNotRecuriseCallingRecursive) {
  const std::string text = R"(
               OpCapability Shader
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %1 "main"
               OpExecutionMode %1 OriginUpperLeft
               OpDecorate %2 DescriptorSet 439418829
       %void = OpTypeVoid
          %4 = OpTypeFunction %void
      %float = OpTypeFloat 32
  %_struct_6 = OpTypeStruct %float %float
          %7 = OpTypeFunction %_struct_6
          %1 = OpFunction %void Pure|Const %4
          %8 = OpLabel
          %2 = OpFunctionCall %_struct_6 %9
               OpKill
               OpFunctionEnd
          %9 = OpFunction %_struct_6 None %7
         %10 = OpLabel
         %11 = OpFunctionCall %_struct_6 %9
               OpUnreachable
               OpFunctionEnd
)";

  std::unique_ptr<IRContext> ctx =
      spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                            SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  ASSERT_TRUE(ctx);

  // Check the premise: the callee really is recursive. Without this, the
  // test would still pass if the cycle were accidentally removed, and it
  // would no longer exercise the "calling a recursive function" case.
  auto* callee = spvtest::GetFunction(ctx->module(), 9);
  ASSERT_NE(nullptr, callee);
  EXPECT_TRUE(callee->IsRecursive());

  auto* func = spvtest::GetFunction(ctx->module(), 1);
  ASSERT_NE(nullptr, func);
  EXPECT_FALSE(func->IsRecursive());
}
224
// The non-semantic OpExtInsts %7 and %8 follow OpFunctionEnd, while %6 sits
// inside the function body. When ForEachInst is asked NOT to visit
// non-semantic instructions, only %6 may be seen.
TEST(FunctionTest, NonSemanticInfoSkipIteration) {
  const std::string text = R"(
               OpCapability Shader
               OpCapability Linkage
               OpExtension "SPV_KHR_non_semantic_info"
          %1 = OpExtInstImport "NonSemantic.Test"
               OpMemoryModel Logical GLSL450
          %2 = OpTypeVoid
          %3 = OpTypeFunction %2
          %4 = OpFunction %2 None %3
          %5 = OpLabel
          %6 = OpExtInst %2 %1 1
               OpReturn
               OpFunctionEnd
          %7 = OpExtInst %2 %1 2
          %8 = OpExtInst %2 %1 3
)";

  std::unique_ptr<IRContext> ctx =
      spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                            SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  ASSERT_TRUE(ctx);
  auto* func = spvtest::GetFunction(ctx->module(), 4);
  ASSERT_NE(nullptr, func);

  std::unordered_set<uint32_t> non_semantic_ids;
  // The final `false` is what makes this the "skip" variant: the trailing
  // non-semantic instructions %7 and %8 are not visited.
  func->ForEachInst(
      [&non_semantic_ids](const Instruction* inst) {
        if (inst->opcode() == spv::Op::OpExtInst) {
          non_semantic_ids.insert(inst->result_id());
        }
      },
      true, false);

  // count() returns size_t; unsigned literals avoid a signed/unsigned
  // comparison inside EXPECT_EQ.
  EXPECT_EQ(1u, non_semantic_ids.count(6));
  EXPECT_EQ(0u, non_semantic_ids.count(7));
  EXPECT_EQ(0u, non_semantic_ids.count(8));
}
261
// Same module as the "skip" variant: %6 is inside the function body and the
// non-semantic OpExtInsts %7 and %8 follow OpFunctionEnd. When ForEachInst is
// asked to visit non-semantic instructions, all three must be seen.
TEST(FunctionTest, NonSemanticInfoIncludeIteration) {
  const std::string text = R"(
               OpCapability Shader
               OpCapability Linkage
               OpExtension "SPV_KHR_non_semantic_info"
          %1 = OpExtInstImport "NonSemantic.Test"
               OpMemoryModel Logical GLSL450
          %2 = OpTypeVoid
          %3 = OpTypeFunction %2
          %4 = OpFunction %2 None %3
          %5 = OpLabel
          %6 = OpExtInst %2 %1 1
               OpReturn
               OpFunctionEnd
          %7 = OpExtInst %2 %1 2
          %8 = OpExtInst %2 %1 3
)";

  std::unique_ptr<IRContext> ctx =
      spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                            SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  ASSERT_TRUE(ctx);
  auto* func = spvtest::GetFunction(ctx->module(), 4);
  ASSERT_NE(nullptr, func);

  std::unordered_set<uint32_t> non_semantic_ids;
  // The final `true` is what makes this the "include" variant: the trailing
  // non-semantic instructions %7 and %8 are visited too.
  func->ForEachInst(
      [&non_semantic_ids](const Instruction* inst) {
        if (inst->opcode() == spv::Op::OpExtInst) {
          non_semantic_ids.insert(inst->result_id());
        }
      },
      true, true);

  // count() returns size_t; unsigned literals avoid a signed/unsigned
  // comparison inside EXPECT_EQ.
  EXPECT_EQ(1u, non_semantic_ids.count(6));
  EXPECT_EQ(1u, non_semantic_ids.count(7));
  EXPECT_EQ(1u, non_semantic_ids.count(8));
}
298
// The SPIR-V below lists the basic blocks of %100 in a scrambled order.
// ReorderBasicBlocksInStructuredOrder() must put them into structured order.
// Block ids were chosen so that the expected result is the entry block %11
// followed by %1, %2, ..., %10: every block's id equals its position.
//
// OpExecutionMode must target the entry-point function %100. Previously it
// named %PSMain, an id that is never defined anywhere in the module.
TEST(FunctionTest, ReorderBlocksinStructuredOrder) {
  const std::string text = R"(
               OpCapability Shader
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %100 "PSMain"
               OpExecutionMode %100 OriginUpperLeft
               OpSource HLSL 600
        %int = OpTypeInt 32 1
       %void = OpTypeVoid
         %19 = OpTypeFunction %void
       %bool = OpTypeBool
 %undef_bool = OpUndef %bool
  %undef_int = OpUndef %int
        %100 = OpFunction %void None %19
         %11 = OpLabel
               OpSelectionMerge %10 None
               OpSwitch %undef_int %3 0 %2 10 %1
          %2 = OpLabel
               OpReturn
          %7 = OpLabel
               OpBranch %8
          %3 = OpLabel
               OpBranch %4
         %10 = OpLabel
               OpReturn
          %9 = OpLabel
               OpBranch %10
          %8 = OpLabel
               OpBranch %4
          %4 = OpLabel
               OpLoopMerge %9 %8 None
               OpBranchConditional %undef_bool %5 %9
          %1 = OpLabel
               OpReturn
          %6 = OpLabel
               OpBranch %7
          %5 = OpLabel
               OpSelectionMerge %7 None
               OpBranchConditional %undef_bool %6 %7
               OpFunctionEnd
)";

  std::unique_ptr<IRContext> ctx =
      spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                            SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  ASSERT_TRUE(ctx);
  auto* func = spvtest::GetFunction(ctx->module(), 100);
  ASSERT_TRUE(func);
  func->ReorderBasicBlocksInStructuredOrder();

  // The function has 11 blocks (%1..%11). Checking the count catches
  // reorderings that drop blocks, which the per-position loop alone would
  // not detect.
  constexpr uint32_t kEntryBlockId = 11;
  ASSERT_EQ(kEntryBlockId, static_cast<uint32_t>(func->end() - func->begin()));

  auto first_block = func->begin();
  EXPECT_EQ(kEntryBlockId, first_block->id());

  // Every non-entry block must sit at the position equal to its id. The cast
  // keeps the comparison unsigned-vs-unsigned.
  auto bb = first_block;
  for (++bb; bb != func->end(); ++bb) {
    EXPECT_EQ(static_cast<uint32_t>(bb - first_block), bb->id());
  }
}
357
358 } // namespace
359 } // namespace opt
360 } // namespace spvtools
361