1 // Copyright (c) 2020 André Perez Maselco
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/fuzz/transformation_adjust_branch_weights.h"
16
17 #include "gtest/gtest.h"
18 #include "source/fuzz/fuzzer_util.h"
19 #include "source/fuzz/instruction_descriptor.h"
20 #include "test/fuzz/fuzz_test_util.h"
21
22 namespace spvtools {
23 namespace fuzz {
24 namespace {
25
TEST(TransformationAdjustBranchWeightsTest, IsApplicableTest) {
  // A fragment shader with two conditional branches:
  //  - the loop-header branch at the end of block %19, which already carries
  //    the weights "1 2";
  //  - the selection branch at the end of block %21, which has no weights.
  std::string shader = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %4 "main" %51 %27
               OpExecutionMode %4 OriginUpperLeft
               OpSource ESSL 310
               OpName %4 "main"
               OpName %25 "buf"
               OpMemberName %25 0 "value"
               OpName %27 ""
               OpName %51 "color"
               OpMemberDecorate %25 0 Offset 0
               OpDecorate %25 Block
               OpDecorate %27 DescriptorSet 0
               OpDecorate %27 Binding 0
               OpDecorate %51 Location 0
          %2 = OpTypeVoid
          %3 = OpTypeFunction %2
          %6 = OpTypeFloat 32
          %7 = OpTypeVector %6 4
        %150 = OpTypeVector %6 2
         %10 = OpConstant %6 0.300000012
         %11 = OpConstant %6 0.400000006
         %12 = OpConstant %6 0.5
         %13 = OpConstant %6 1
         %14 = OpConstantComposite %7 %10 %11 %12 %13
         %15 = OpTypeInt 32 1
         %18 = OpConstant %15 0
         %25 = OpTypeStruct %6
         %26 = OpTypePointer Uniform %25
         %27 = OpVariable %26 Uniform
         %28 = OpTypePointer Uniform %6
         %32 = OpTypeBool
        %103 = OpConstantTrue %32
         %34 = OpConstant %6 0.100000001
         %48 = OpConstant %15 1
         %50 = OpTypePointer Output %7
         %51 = OpVariable %50 Output
        %100 = OpTypePointer Function %6
          %4 = OpFunction %2 None %3
          %5 = OpLabel
        %101 = OpVariable %100 Function
        %102 = OpVariable %100 Function
               OpBranch %19
         %19 = OpLabel
         %60 = OpPhi %7 %14 %5 %58 %20
         %59 = OpPhi %15 %18 %5 %49 %20
         %29 = OpAccessChain %28 %27 %18
         %30 = OpLoad %6 %29
         %31 = OpConvertFToS %15 %30
         %33 = OpSLessThan %32 %59 %31
               OpLoopMerge %21 %20 None
               OpBranchConditional %33 %20 %21 1 2
         %20 = OpLabel
         %39 = OpCompositeExtract %6 %60 0
         %40 = OpFAdd %6 %39 %34
         %55 = OpCompositeInsert %7 %40 %60 0
         %44 = OpCompositeExtract %6 %60 1
         %45 = OpFSub %6 %44 %34
         %58 = OpCompositeInsert %7 %45 %55 1
         %49 = OpIAdd %15 %59 %48
               OpBranch %19
         %21 = OpLabel
               OpStore %51 %60
               OpSelectionMerge %105 None
               OpBranchConditional %103 %104 %105
        %104 = OpLabel
               OpBranch %105
        %105 = OpLabel
               OpReturn
               OpFunctionEnd
  )";

  const auto env = SPV_ENV_UNIVERSAL_1_5;
  const auto consumer = nullptr;
  const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
  spvtools::ValidatorOptions validator_options;
  ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
                                               kConsoleMessageConsumer));
  TransformationContext transformation_context(
      MakeUnique<FactManager>(context.get()), validator_options);

  // Descriptor of the OpBranchConditional following %33 (already weighted).
  const auto weighted_branch =
      MakeInstructionDescriptor(33, SpvOpBranchConditional, 0);
  // Descriptor of the OpBranchConditional in block %21 (no weights yet).
  const auto unweighted_branch =
      MakeInstructionDescriptor(21, SpvOpBranchConditional, 0);

  // Existing weights on an OpBranchConditional may be replaced.
  ASSERT_TRUE(TransformationAdjustBranchWeights(weighted_branch, {0, 1})
                  .IsApplicable(context.get(), transformation_context));

#ifndef NDEBUG
  // Two zero weights violate a precondition and trigger an assertion.
  ASSERT_DEATH(TransformationAdjustBranchWeights(weighted_branch, {0, 0})
                   .IsApplicable(context.get(), transformation_context),
               "At least one weight must be non-zero");
#endif

  // A weight sum of exactly UINT32_MAX still fits in 32 unsigned bits...
  ASSERT_TRUE(
      TransformationAdjustBranchWeights(weighted_branch, {UINT32_MAX, 0})
          .IsApplicable(context.get(), transformation_context));

#ifndef NDEBUG
  // ...whereas a sum of UINT32_MAX + 1 overflows and triggers an assertion.
  ASSERT_DEATH(
      TransformationAdjustBranchWeights(weighted_branch, {1, UINT32_MAX})
          .IsApplicable(context.get(), transformation_context),
      "The sum of the two weights must not be greater than UINT32_MAX");
#endif

  // Weights may be added to an OpBranchConditional that has none.
  ASSERT_TRUE(TransformationAdjustBranchWeights(unweighted_branch, {0, 1})
                  .IsApplicable(context.get(), transformation_context));

  // Instructions other than OpBranchConditional are rejected.
  const auto type_void = MakeInstructionDescriptor(2, SpvOpTypeVoid, 0);
  ASSERT_FALSE(TransformationAdjustBranchWeights(type_void, {5, 6})
                   .IsApplicable(context.get(), transformation_context));

  const auto loop_continue_label = MakeInstructionDescriptor(20, SpvOpLabel, 0);
  ASSERT_FALSE(TransformationAdjustBranchWeights(loop_continue_label, {1, 2})
                   .IsApplicable(context.get(), transformation_context));

  const auto integer_add = MakeInstructionDescriptor(49, SpvOpIAdd, 0);
  ASSERT_FALSE(TransformationAdjustBranchWeights(integer_add, {1, 2})
                   .IsApplicable(context.get(), transformation_context));
}
173
TEST(TransformationAdjustBranchWeightsTest, ApplyTest) {
  // Input module: the loop-header OpBranchConditional (after %33) carries the
  // weights "1 2"; the OpBranchConditional in block %21 carries none.
  std::string shader = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %4 "main" %51 %27
               OpExecutionMode %4 OriginUpperLeft
               OpSource ESSL 310
               OpName %4 "main"
               OpName %25 "buf"
               OpMemberName %25 0 "value"
               OpName %27 ""
               OpName %51 "color"
               OpMemberDecorate %25 0 Offset 0
               OpDecorate %25 Block
               OpDecorate %27 DescriptorSet 0
               OpDecorate %27 Binding 0
               OpDecorate %51 Location 0
          %2 = OpTypeVoid
          %3 = OpTypeFunction %2
          %6 = OpTypeFloat 32
          %7 = OpTypeVector %6 4
        %150 = OpTypeVector %6 2
         %10 = OpConstant %6 0.300000012
         %11 = OpConstant %6 0.400000006
         %12 = OpConstant %6 0.5
         %13 = OpConstant %6 1
         %14 = OpConstantComposite %7 %10 %11 %12 %13
         %15 = OpTypeInt 32 1
         %18 = OpConstant %15 0
         %25 = OpTypeStruct %6
         %26 = OpTypePointer Uniform %25
         %27 = OpVariable %26 Uniform
         %28 = OpTypePointer Uniform %6
         %32 = OpTypeBool
        %103 = OpConstantTrue %32
         %34 = OpConstant %6 0.100000001
         %48 = OpConstant %15 1
         %50 = OpTypePointer Output %7
         %51 = OpVariable %50 Output
        %100 = OpTypePointer Function %6
          %4 = OpFunction %2 None %3
          %5 = OpLabel
        %101 = OpVariable %100 Function
        %102 = OpVariable %100 Function
               OpBranch %19
         %19 = OpLabel
         %60 = OpPhi %7 %14 %5 %58 %20
         %59 = OpPhi %15 %18 %5 %49 %20
         %29 = OpAccessChain %28 %27 %18
         %30 = OpLoad %6 %29
         %31 = OpConvertFToS %15 %30
         %33 = OpSLessThan %32 %59 %31
               OpLoopMerge %21 %20 None
               OpBranchConditional %33 %20 %21 1 2
         %20 = OpLabel
         %39 = OpCompositeExtract %6 %60 0
         %40 = OpFAdd %6 %39 %34
         %55 = OpCompositeInsert %7 %40 %60 0
         %44 = OpCompositeExtract %6 %60 1
         %45 = OpFSub %6 %44 %34
         %58 = OpCompositeInsert %7 %45 %55 1
         %49 = OpIAdd %15 %59 %48
               OpBranch %19
         %21 = OpLabel
               OpStore %51 %60
               OpSelectionMerge %105 None
               OpBranchConditional %103 %104 %105
        %104 = OpLabel
               OpBranch %105
        %105 = OpLabel
               OpReturn
               OpFunctionEnd
  )";

  const auto env = SPV_ENV_UNIVERSAL_1_5;
  const auto consumer = nullptr;
  const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
  spvtools::ValidatorOptions validator_options;
  ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
                                               kConsoleMessageConsumer));
  TransformationContext transformation_context(
      MakeUnique<FactManager>(context.get()), validator_options);

  // Replace the existing weights "1 2" on the loop-header branch with "5 6".
  // Check applicability first so a failure points at the precondition rather
  // than at a confusing diff in the final module comparison.
  auto instruction_descriptor =
      MakeInstructionDescriptor(33, SpvOpBranchConditional, 0);
  auto transformation =
      TransformationAdjustBranchWeights(instruction_descriptor, {5, 6});
  ASSERT_TRUE(
      transformation.IsApplicable(context.get(), transformation_context));
  ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);

  // Add the weights "7 8" to the branch in block %21, which had none.
  instruction_descriptor =
      MakeInstructionDescriptor(21, SpvOpBranchConditional, 0);
  transformation =
      TransformationAdjustBranchWeights(instruction_descriptor, {7, 8});
  ASSERT_TRUE(
      transformation.IsApplicable(context.get(), transformation_context));
  ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);

  // The transformed module must remain valid SPIR-V.
  ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
                                               kConsoleMessageConsumer));

  // Expected module: identical to the input except for the branch weights.
  std::string variant_shader = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %4 "main" %51 %27
               OpExecutionMode %4 OriginUpperLeft
               OpSource ESSL 310
               OpName %4 "main"
               OpName %25 "buf"
               OpMemberName %25 0 "value"
               OpName %27 ""
               OpName %51 "color"
               OpMemberDecorate %25 0 Offset 0
               OpDecorate %25 Block
               OpDecorate %27 DescriptorSet 0
               OpDecorate %27 Binding 0
               OpDecorate %51 Location 0
          %2 = OpTypeVoid
          %3 = OpTypeFunction %2
          %6 = OpTypeFloat 32
          %7 = OpTypeVector %6 4
        %150 = OpTypeVector %6 2
         %10 = OpConstant %6 0.300000012
         %11 = OpConstant %6 0.400000006
         %12 = OpConstant %6 0.5
         %13 = OpConstant %6 1
         %14 = OpConstantComposite %7 %10 %11 %12 %13
         %15 = OpTypeInt 32 1
         %18 = OpConstant %15 0
         %25 = OpTypeStruct %6
         %26 = OpTypePointer Uniform %25
         %27 = OpVariable %26 Uniform
         %28 = OpTypePointer Uniform %6
         %32 = OpTypeBool
        %103 = OpConstantTrue %32
         %34 = OpConstant %6 0.100000001
         %48 = OpConstant %15 1
         %50 = OpTypePointer Output %7
         %51 = OpVariable %50 Output
        %100 = OpTypePointer Function %6
          %4 = OpFunction %2 None %3
          %5 = OpLabel
        %101 = OpVariable %100 Function
        %102 = OpVariable %100 Function
               OpBranch %19
         %19 = OpLabel
         %60 = OpPhi %7 %14 %5 %58 %20
         %59 = OpPhi %15 %18 %5 %49 %20
         %29 = OpAccessChain %28 %27 %18
         %30 = OpLoad %6 %29
         %31 = OpConvertFToS %15 %30
         %33 = OpSLessThan %32 %59 %31
               OpLoopMerge %21 %20 None
               OpBranchConditional %33 %20 %21 5 6
         %20 = OpLabel
         %39 = OpCompositeExtract %6 %60 0
         %40 = OpFAdd %6 %39 %34
         %55 = OpCompositeInsert %7 %40 %60 0
         %44 = OpCompositeExtract %6 %60 1
         %45 = OpFSub %6 %44 %34
         %58 = OpCompositeInsert %7 %45 %55 1
         %49 = OpIAdd %15 %59 %48
               OpBranch %19
         %21 = OpLabel
               OpStore %51 %60
               OpSelectionMerge %105 None
               OpBranchConditional %103 %104 %105 7 8
        %104 = OpLabel
               OpBranch %105
        %105 = OpLabel
               OpReturn
               OpFunctionEnd
  )";

  ASSERT_TRUE(IsEqual(env, variant_shader, context.get()));
}
345
346 } // namespace
347 } // namespace fuzz
348 } // namespace spvtools
349