1 // Copyright (c) 2016 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
#include <cstdint>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "spirv-tools/optimizer.hpp"
#include "spirv/1.1/spirv.h"
23
24 namespace spvtools {
25 namespace {
26
27 using ::testing::ContainerEq;
28 using ::testing::HasSubstr;
29
// Returns the smallest instruction sequence (capabilities plus memory model)
// that forms a valid module. Further instructions may be appended to the
// returned text; it always ends with a newline.
std::string Header() {
  return "OpCapability Shader\n"
         "OpCapability Linkage\n"
         "OpMemoryModel Logical GLSL450\n";
}
38
// Module-header version word expected when assembling for a SPIR-V 1.1
// target environment. The word stores the major version in bits 16..23 and
// the minor version in bits 8..15, i.e. 0x00010100 for version 1.1.
constexpr uint32_t kExpectedSpvVersion = (1u << 16) | (1u << 8);
42
// Assembles a one-instruction module and checks its header words. Then
// checks that validation rejects it (%1 is never defined) through the message
// consumer, and that disassembly reproduces the input text exactly.
TEST(CppInterface, SuccessfulRoundTrip) {
  const std::string input_text = "%2 = OpSizeOf %1 %3\n";
  SpirvTools t(SPV_ENV_UNIVERSAL_1_1);

  std::vector<uint32_t> binary;
  // Fatal assertions: if assembly fails, |binary| may still be empty, and
  // reading the header words below would index out of bounds.
  ASSERT_TRUE(t.Assemble(input_text, &binary));
  ASSERT_GT(binary.size(), 5u);
  EXPECT_EQ(SpvMagicNumber, binary[0]);
  EXPECT_EQ(kExpectedSpvVersion, binary[1]);

  // This cannot pass validation since %1 is not defined.
  int invocation_count = 0;
  t.SetMessageConsumer([&invocation_count](spv_message_level_t level,
                                           const char* source,
                                           const spv_position_t& position,
                                           const char* message) {
    ++invocation_count;
    EXPECT_EQ(SPV_MSG_ERROR, level);
    EXPECT_STREQ("input", source);
    EXPECT_EQ(0u, position.line);
    EXPECT_EQ(0u, position.column);
    EXPECT_EQ(1u, position.index);
    EXPECT_STREQ("ID 1[%1] has not been defined\n %2 = OpSizeOf %1 %3\n",
                 message);
  });
  EXPECT_FALSE(t.Validate(binary));
  // Without this, the checks inside the consumer would pass vacuously if no
  // diagnostic were ever reported.
  EXPECT_GE(invocation_count, 1);

  std::string output_text;
  EXPECT_TRUE(t.Disassemble(binary, &output_text));
  EXPECT_EQ(input_text, output_text);
}
70
// Assembling empty text yields a module consisting of only the 5-word header.
// The output vector's previous contents are replaced.
TEST(CppInterface, AssembleEmptyModule) {
  // Pre-filled with junk so the size check shows Assemble overwrote it.
  std::vector<uint32_t> binary(10, 42);
  SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
  ASSERT_TRUE(t.Assemble("", &binary));
  // We only have the header. Fatal so the header words below are read only
  // when the vector really has the expected length.
  ASSERT_EQ(5u, binary.size());
  EXPECT_EQ(SpvMagicNumber, binary[0]);
  EXPECT_EQ(kExpectedSpvVersion, binary[1]);
}
80
// Exercises both Assemble overloads: std::string and pointer+length. The
// pointer+length form is also given a length that drops the trailing newline.
// Each must produce a module with a valid header.
TEST(CppInterface, AssembleOverloads) {
  const std::string input_text = "%2 = OpSizeOf %1 %3\n";
  SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
  // In every scope the assertions before the indexing are fatal. A failed
  // assembly leaves |binary| empty, and binary[0] would be out of bounds.
  {
    std::vector<uint32_t> binary;
    ASSERT_TRUE(t.Assemble(input_text, &binary));
    ASSERT_GT(binary.size(), 5u);
    EXPECT_EQ(SpvMagicNumber, binary[0]);
    EXPECT_EQ(kExpectedSpvVersion, binary[1]);
  }
  {
    std::vector<uint32_t> binary;
    ASSERT_TRUE(t.Assemble(input_text.data(), input_text.size(), &binary));
    ASSERT_GT(binary.size(), 5u);
    EXPECT_EQ(SpvMagicNumber, binary[0]);
    EXPECT_EQ(kExpectedSpvVersion, binary[1]);
  }
  {  // Ignore the last newline.
    std::vector<uint32_t> binary;
    ASSERT_TRUE(t.Assemble(input_text.data(), input_text.size() - 1, &binary));
    ASSERT_GT(binary.size(), 5u);
    EXPECT_EQ(SpvMagicNumber, binary[0]);
    EXPECT_EQ(kExpectedSpvVersion, binary[1]);
  }
}
106
// Disassembling an empty binary fails and reports exactly one error through
// the message consumer. The caller's output string is left untouched.
TEST(CppInterface, DisassembleEmptyModule) {
  SpirvTools tools(SPV_ENV_UNIVERSAL_1_1);
  int calls = 0;
  tools.SetMessageConsumer([&calls](spv_message_level_t lvl, const char* src,
                                    const spv_position_t& pos,
                                    const char* msg) {
    ++calls;
    EXPECT_EQ(SPV_MSG_ERROR, lvl);
    EXPECT_STREQ("input", src);
    EXPECT_EQ(0u, pos.line);
    EXPECT_EQ(0u, pos.column);
    EXPECT_EQ(0u, pos.index);
    EXPECT_STREQ("Missing module.", msg);
  });

  std::string text(10, 'x');
  const std::vector<uint32_t> no_words;
  EXPECT_FALSE(tools.Disassemble(no_words, &text));
  EXPECT_EQ("xxxxxxxxxx", text);  // The original string is unmodified.
  EXPECT_EQ(1, calls);
}
126
// Both Disassemble overloads (vector, and pointer+size) reproduce the text
// that was assembled.
TEST(CppInterface, DisassembleOverloads) {
  const std::string source_text = "%2 = OpSizeOf %1 %3\n";
  SpirvTools tools(SPV_ENV_UNIVERSAL_1_1);

  std::vector<uint32_t> words;
  EXPECT_TRUE(tools.Assemble(source_text, &words));

  std::string from_vector;
  EXPECT_TRUE(tools.Disassemble(words, &from_vector));
  EXPECT_EQ(source_text, from_vector);

  std::string from_pointer;
  EXPECT_TRUE(tools.Disassemble(words.data(), words.size(), &from_pointer));
  EXPECT_EQ(source_text, from_pointer);
}
145
// A header-only module validates successfully and no diagnostics reach the
// message consumer.
TEST(CppInterface, SuccessfulValidation) {
  SpirvTools tools(SPV_ENV_UNIVERSAL_1_1);
  int diagnostics = 0;
  tools.SetMessageConsumer(
      [&diagnostics](spv_message_level_t, const char*, const spv_position_t&,
                     const char*) { diagnostics += 1; });

  std::vector<uint32_t> words;
  EXPECT_TRUE(tools.Assemble(Header(), &words));
  EXPECT_TRUE(tools.Validate(words));
  EXPECT_EQ(0, diagnostics);
}
159
// The vector and pointer+size Validate overloads both accept a header-only
// module.
TEST(CppInterface, ValidateOverloads) {
  SpirvTools tools(SPV_ENV_UNIVERSAL_1_1);
  std::vector<uint32_t> words;
  EXPECT_TRUE(tools.Assemble(Header(), &words));

  EXPECT_TRUE(tools.Validate(words));
  EXPECT_TRUE(tools.Validate(words.data(), words.size()));
}
168
// Validating an empty binary fails with exactly one error about the missing
// magic number.
TEST(CppInterface, ValidateEmptyModule) {
  SpirvTools tools(SPV_ENV_UNIVERSAL_1_1);
  int calls = 0;
  tools.SetMessageConsumer([&calls](spv_message_level_t lvl, const char* src,
                                    const spv_position_t& pos,
                                    const char* msg) {
    ++calls;
    EXPECT_EQ(SPV_MSG_ERROR, lvl);
    EXPECT_STREQ("input", src);
    EXPECT_EQ(0u, pos.line);
    EXPECT_EQ(0u, pos.column);
    EXPECT_EQ(0u, pos.index);
    EXPECT_STREQ("Invalid SPIR-V magic number.", msg);
  });

  const std::vector<uint32_t> no_words;
  EXPECT_FALSE(tools.Validate(no_words));
  EXPECT_EQ(1, calls);
}
186
187 // Returns the assembly for a SPIR-V module with a struct declaration
188 // with the given number of members.
MakeModuleHavingStruct(int num_members)189 std::string MakeModuleHavingStruct(int num_members) {
190 std::stringstream os;
191 os << Header();
192 os << R"(%1 = OpTypeInt 32 0
193 %2 = OpTypeStruct)";
194 for (int i = 0; i < num_members; i++) os << " %1";
195 return os.str();
196 }
197
// With default validator options, a struct with 10 members is accepted.
TEST(CppInterface, ValidateWithOptionsPass) {
  SpirvTools tools(SPV_ENV_UNIVERSAL_1_1);
  std::vector<uint32_t> words;
  EXPECT_TRUE(tools.Assemble(MakeModuleHavingStruct(10), &words));

  const ValidatorOptions default_options;
  EXPECT_TRUE(tools.Validate(words.data(), words.size(), default_options));
}
206
// Lowering the struct-member limit to 9 makes a 10-member struct fail
// validation. The reported diagnostics must mention both numbers.
TEST(CppInterface, ValidateWithOptionsFail) {
  SpirvTools tools(SPV_ENV_UNIVERSAL_1_1);
  std::vector<uint32_t> words;
  EXPECT_TRUE(tools.Assemble(MakeModuleHavingStruct(10), &words));

  ValidatorOptions options;
  options.SetUniversalLimit(spv_validator_limit_max_struct_members, 9);

  // Concatenate every diagnostic so the expected one can be searched for.
  std::string diagnostics;
  tools.SetMessageConsumer(
      [&diagnostics](spv_message_level_t, const char*, const spv_position_t&,
                     const char* message) { diagnostics += message; });

  EXPECT_FALSE(tools.Validate(words.data(), words.size(), options));
  EXPECT_THAT(
      diagnostics,
      HasSubstr(
          "Number of OpTypeStruct members (10) has exceeded the limit (9)"));
}
224
225 // Checks that after running the given optimizer |opt| on the given |original|
226 // source code, we can get the given |optimized| source code.
CheckOptimization(const std::string & original,const std::string & optimized,const Optimizer & opt)227 void CheckOptimization(const std::string& original,
228 const std::string& optimized, const Optimizer& opt) {
229 SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
230 std::vector<uint32_t> original_binary;
231 ASSERT_TRUE(t.Assemble(original, &original_binary));
232
233 std::vector<uint32_t> optimized_binary;
234 EXPECT_TRUE(opt.Run(original_binary.data(), original_binary.size(),
235 &optimized_binary));
236
237 std::string optimized_text;
238 EXPECT_TRUE(t.Disassemble(optimized_binary, &optimized_text));
239 EXPECT_EQ(optimized, optimized_text);
240 }
241
// Optimizing the module assembled from empty text fails because that module
// does not validate.
TEST(CppInterface, OptimizeEmptyModule) {
  SpirvTools tools(SPV_ENV_UNIVERSAL_1_1);
  std::vector<uint32_t> words;
  EXPECT_TRUE(tools.Assemble("", &words));

  Optimizer optimizer(SPV_ENV_UNIVERSAL_1_1);
  optimizer.RegisterPass(CreateStripDebugInfoPass());

  // The same vector serves as both input and output here.
  const bool succeeded = optimizer.Run(words.data(), words.size(), &words);
  EXPECT_FALSE(succeeded);
}
253
// The strip-debug-info pass removes an OpSource instruction, leaving only
// the header.
TEST(CppInterface, OptimizeModifiedModule) {
  const std::string with_debug_info = Header() + "OpSource GLSL 450";
  Optimizer optimizer(SPV_ENV_UNIVERSAL_1_1);
  optimizer.RegisterPass(CreateStripDebugInfoPass());
  CheckOptimization(with_debug_info, Header(), optimizer);
}
259
// Two registered passes both take effect. Debug info (OpSource, OpDecorate)
// is stripped, and the spec constant is frozen into a regular constant.
TEST(CppInterface, OptimizeMulitplePasses) {
  const std::string input = Header() +
                            "OpSource GLSL 450 "
                            "OpDecorate %true SpecId 1 "
                            "%bool = OpTypeBool "
                            "%true = OpSpecConstantTrue %bool";
  const std::string expected = Header() +
                               "%bool = OpTypeBool\n"
                               "%true = OpConstantTrue %bool\n";

  Optimizer optimizer(SPV_ENV_UNIVERSAL_1_1);
  optimizer.RegisterPass(CreateStripDebugInfoPass());
  optimizer.RegisterPass(CreateFreezeSpecConstantValuePass());

  CheckOptimization(input, expected, optimizer);
}
277
// Creates pass tokens without ever registering them with an Optimizer. One is
// a discarded temporary; the other is a named local that goes out of scope.
TEST(CppInterface, OptimizeDoNothingWithPassToken) {
  CreateFreezeSpecConstantValuePass();
  auto never_registered = CreateUnifyConstantPass();
}
282
// A token first holding the null pass is reassigned to the strip-debug-info
// pass. Registering it must run the strip pass.
TEST(CppInterface, OptimizeReassignPassToken) {
  Optimizer::PassToken pass = CreateNullPass();
  pass = CreateStripDebugInfoPass();

  Optimizer optimizer(SPV_ENV_UNIVERSAL_1_1);
  optimizer.RegisterPass(std::move(pass));
  CheckOptimization(Header() + "OpSource GLSL 450", Header(), optimizer);
}
291
// A token move-constructed from a strip-debug-info token still runs that
// pass once registered.
TEST(CppInterface, OptimizeMoveConstructPassToken) {
  auto source_token = CreateStripDebugInfoPass();
  Optimizer::PassToken moved_token{std::move(source_token)};

  Optimizer optimizer(SPV_ENV_UNIVERSAL_1_1);
  optimizer.RegisterPass(std::move(moved_token));
  CheckOptimization(Header() + "OpSource GLSL 450", Header(), optimizer);
}
300
// Move-assigning a strip-debug-info token over a null-pass token transfers
// the strip pass to the destination token.
TEST(CppInterface, OptimizeMoveAssignPassToken) {
  auto strip_token = CreateStripDebugInfoPass();
  auto target_token = CreateNullPass();
  target_token = std::move(strip_token);

  Optimizer optimizer(SPV_ENV_UNIVERSAL_1_1);
  optimizer.RegisterPass(std::move(target_token));
  CheckOptimization(Header() + "OpSource GLSL 450", Header(), optimizer);
}
310
// Optimizer::Run works when the input words and the output vector are the
// same object.
TEST(CppInterface, OptimizeSameAddressForOriginalOptimizedBinary) {
  SpirvTools tools(SPV_ENV_UNIVERSAL_1_1);
  std::vector<uint32_t> words;
  ASSERT_TRUE(tools.Assemble(Header() + "OpSource GLSL 450", &words));

  Optimizer optimizer(SPV_ENV_UNIVERSAL_1_1);
  optimizer.RegisterPass(CreateStripDebugInfoPass());
  // Input and output alias the same vector.
  EXPECT_TRUE(optimizer.Run(words.data(), words.size(), &words));

  std::string text;
  EXPECT_TRUE(tools.Disassemble(words, &text));
  EXPECT_EQ(Header(), text);
}
324
// TODO(antiagainst): Add tests for SetMessageConsumer().
326
327 } // namespace
328 } // namespace spvtools
329