1 // Copyright (c) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/fuzz/shrinker.h"
16
17 #include <sstream>
18
19 #include "source/fuzz/added_function_reducer.h"
20 #include "source/fuzz/pseudo_random_generator.h"
21 #include "source/fuzz/replayer.h"
22 #include "source/opt/build_module.h"
23 #include "source/opt/ir_context.h"
24 #include "source/spirv_fuzzer_options.h"
25 #include "source/util/make_unique.h"
26
27 namespace spvtools {
28 namespace fuzz {
29
30 namespace {
31
32 // A helper to get the size of a protobuf transformation sequence in a less
33 // verbose manner.
NumRemainingTransformations(const protobufs::TransformationSequence & transformation_sequence)34 uint32_t NumRemainingTransformations(
35 const protobufs::TransformationSequence& transformation_sequence) {
36 return static_cast<uint32_t>(transformation_sequence.transformation_size());
37 }
38
39 // A helper to return a transformation sequence identical to |transformations|,
40 // except that a chunk of size |chunk_size| starting from |chunk_index| x
41 // |chunk_size| is removed (or as many transformations as available if the whole
42 // chunk is not).
RemoveChunk(const protobufs::TransformationSequence & transformations,uint32_t chunk_index,uint32_t chunk_size)43 protobufs::TransformationSequence RemoveChunk(
44 const protobufs::TransformationSequence& transformations,
45 uint32_t chunk_index, uint32_t chunk_size) {
46 uint32_t lower = chunk_index * chunk_size;
47 uint32_t upper = std::min((chunk_index + 1) * chunk_size,
48 NumRemainingTransformations(transformations));
49 assert(lower < upper);
50 assert(upper <= NumRemainingTransformations(transformations));
51 protobufs::TransformationSequence result;
52 for (uint32_t j = 0; j < NumRemainingTransformations(transformations); j++) {
53 if (j >= lower && j < upper) {
54 continue;
55 }
56 protobufs::Transformation transformation =
57 transformations.transformation()[j];
58 *result.mutable_transformation()->Add() = transformation;
59 }
60 return result;
61 }
62
63 } // namespace
64
Shrinker(spv_target_env target_env,MessageConsumer consumer,const std::vector<uint32_t> & binary_in,const protobufs::FactSequence & initial_facts,const protobufs::TransformationSequence & transformation_sequence_in,const InterestingnessFunction & interestingness_function,uint32_t step_limit,bool validate_during_replay,spv_validator_options validator_options)65 Shrinker::Shrinker(
66 spv_target_env target_env, MessageConsumer consumer,
67 const std::vector<uint32_t>& binary_in,
68 const protobufs::FactSequence& initial_facts,
69 const protobufs::TransformationSequence& transformation_sequence_in,
70 const InterestingnessFunction& interestingness_function,
71 uint32_t step_limit, bool validate_during_replay,
72 spv_validator_options validator_options)
73 : target_env_(target_env),
74 consumer_(std::move(consumer)),
75 binary_in_(binary_in),
76 initial_facts_(initial_facts),
77 transformation_sequence_in_(transformation_sequence_in),
78 interestingness_function_(interestingness_function),
79 step_limit_(step_limit),
80 validate_during_replay_(validate_during_replay),
81 validator_options_(validator_options) {}
82
83 Shrinker::~Shrinker() = default;
84
Run()85 Shrinker::ShrinkerResult Shrinker::Run() {
86 // Check compatibility between the library version being linked with and the
87 // header files being used.
88 GOOGLE_PROTOBUF_VERIFY_VERSION;
89
90 SpirvTools tools(target_env_);
91 if (!tools.IsValid()) {
92 consumer_(SPV_MSG_ERROR, nullptr, {},
93 "Failed to create SPIRV-Tools interface; stopping.");
94 return {Shrinker::ShrinkerResultStatus::kFailedToCreateSpirvToolsInterface,
95 std::vector<uint32_t>(), protobufs::TransformationSequence()};
96 }
97
98 // Initial binary should be valid.
99 if (!tools.Validate(&binary_in_[0], binary_in_.size(), validator_options_)) {
100 consumer_(SPV_MSG_INFO, nullptr, {},
101 "Initial binary is invalid; stopping.");
102 return {Shrinker::ShrinkerResultStatus::kInitialBinaryInvalid,
103 std::vector<uint32_t>(), protobufs::TransformationSequence()};
104 }
105
106 // Run a replay of the initial transformation sequence to check that it
107 // succeeds.
108 auto initial_replay_result =
109 Replayer(target_env_, consumer_, binary_in_, initial_facts_,
110 transformation_sequence_in_,
111 static_cast<uint32_t>(
112 transformation_sequence_in_.transformation_size()),
113 validate_during_replay_, validator_options_)
114 .Run();
115 if (initial_replay_result.status !=
116 Replayer::ReplayerResultStatus::kComplete) {
117 return {ShrinkerResultStatus::kReplayFailed, std::vector<uint32_t>(),
118 protobufs::TransformationSequence()};
119 }
120 // Get the binary that results from running these transformations, and the
121 // subsequence of the initial transformations that actually apply (in
122 // principle this could be a strict subsequence).
123 std::vector<uint32_t> current_best_binary;
124 initial_replay_result.transformed_module->module()->ToBinary(
125 ¤t_best_binary, false);
126 protobufs::TransformationSequence current_best_transformations =
127 std::move(initial_replay_result.applied_transformations);
128
129 // Check that the binary produced by applying the initial transformations is
130 // indeed interesting.
131 if (!interestingness_function_(current_best_binary, 0)) {
132 consumer_(SPV_MSG_INFO, nullptr, {},
133 "Initial binary is not interesting; stopping.");
134 return {ShrinkerResultStatus::kInitialBinaryNotInteresting,
135 std::vector<uint32_t>(), protobufs::TransformationSequence()};
136 }
137
138 uint32_t attempt = 0; // Keeps track of the number of shrink attempts that
139 // have been tried, whether successful or not.
140
141 uint32_t chunk_size =
142 std::max(1u, NumRemainingTransformations(current_best_transformations) /
143 2); // The number of contiguous transformations that the
144 // shrinker will try to remove in one go; starts
145 // high and decreases during the shrinking process.
146
147 // Keep shrinking until we:
148 // - reach the step limit,
149 // - run out of transformations to remove, or
150 // - cannot make the chunk size any smaller.
151 while (attempt < step_limit_ &&
152 !current_best_transformations.transformation().empty() &&
153 chunk_size > 0) {
154 bool progress_this_round =
155 false; // Used to decide whether to make the chunk size with which we
156 // remove transformations smaller. If we managed to remove at
157 // least one chunk of transformations at a particular chunk
158 // size, we set this flag so that we do not yet decrease the
159 // chunk size.
160
161 assert(chunk_size <=
162 NumRemainingTransformations(current_best_transformations) &&
163 "Chunk size should never exceed the number of transformations that "
164 "remain.");
165
166 // The number of chunks is the ceiling of (#remaining_transformations /
167 // chunk_size).
168 const uint32_t num_chunks =
169 (NumRemainingTransformations(current_best_transformations) +
170 chunk_size - 1) /
171 chunk_size;
172 assert(num_chunks >= 1 && "There should be at least one chunk.");
173 assert(num_chunks * chunk_size >=
174 NumRemainingTransformations(current_best_transformations) &&
175 "All transformations should be in some chunk.");
176
177 // We go through the transformations in reverse, in chunks of size
178 // |chunk_size|, using |chunk_index| to track which chunk to try removing
179 // next. The loop exits early if we reach the shrinking step limit.
180 for (int chunk_index = num_chunks - 1;
181 attempt < step_limit_ && chunk_index >= 0; chunk_index--) {
182 // Remove a chunk of transformations according to the current index and
183 // chunk size.
184 auto transformations_with_chunk_removed =
185 RemoveChunk(current_best_transformations,
186 static_cast<uint32_t>(chunk_index), chunk_size);
187
188 // Replay the smaller sequence of transformations to get a next binary and
189 // transformation sequence. Note that the transformations arising from
190 // replay might be even smaller than the transformations with the chunk
191 // removed, because removing those transformations might make further
192 // transformations inapplicable.
193 auto replay_result =
194 Replayer(
195 target_env_, consumer_, binary_in_, initial_facts_,
196 transformations_with_chunk_removed,
197 static_cast<uint32_t>(
198 transformations_with_chunk_removed.transformation_size()),
199 validate_during_replay_, validator_options_)
200 .Run();
201 if (replay_result.status != Replayer::ReplayerResultStatus::kComplete) {
202 // Replay should not fail; if it does, we need to abort shrinking.
203 return {ShrinkerResultStatus::kReplayFailed, std::vector<uint32_t>(),
204 protobufs::TransformationSequence()};
205 }
206
207 assert(
208 NumRemainingTransformations(replay_result.applied_transformations) >=
209 chunk_index * chunk_size &&
210 "Removing this chunk of transformations should not have an effect "
211 "on earlier chunks.");
212
213 std::vector<uint32_t> transformed_binary;
214 replay_result.transformed_module->module()->ToBinary(&transformed_binary,
215 false);
216 if (interestingness_function_(transformed_binary, attempt)) {
217 // If the binary arising from the smaller transformation sequence is
218 // interesting, this becomes our current best binary and transformation
219 // sequence.
220 current_best_binary = std::move(transformed_binary);
221 current_best_transformations =
222 std::move(replay_result.applied_transformations);
223 progress_this_round = true;
224 }
225 // Either way, this was a shrink attempt, so increment our count of shrink
226 // attempts.
227 attempt++;
228 }
229 if (!progress_this_round) {
230 // If we didn't manage to remove any chunks at this chunk size, try a
231 // smaller chunk size.
232 chunk_size /= 2;
233 }
234 // Decrease the chunk size until it becomes no larger than the number of
235 // remaining transformations.
236 while (chunk_size >
237 NumRemainingTransformations(current_best_transformations)) {
238 chunk_size /= 2;
239 }
240 }
241
242 // We now use spirv-reduce to minimise the functions associated with any
243 // AddFunction transformations that remain.
244 //
245 // Consider every remaining transformation.
246 for (uint32_t transformation_index = 0;
247 attempt < step_limit_ &&
248 transformation_index <
249 static_cast<uint32_t>(
250 current_best_transformations.transformation_size());
251 transformation_index++) {
252 // Skip all transformations apart from TransformationAddFunction.
253 if (!current_best_transformations.transformation(transformation_index)
254 .has_add_function()) {
255 continue;
256 }
257 // Invoke spirv-reduce on the function encoded in this AddFunction
258 // transformation. The details of this are rather involved, and so are
259 // encapsulated in a separate class.
260 auto added_function_reducer_result =
261 AddedFunctionReducer(target_env_, consumer_, binary_in_, initial_facts_,
262 current_best_transformations, transformation_index,
263 interestingness_function_, validate_during_replay_,
264 validator_options_, step_limit_, attempt)
265 .Run();
266 // Reducing the added function should succeed. If it doesn't, we report
267 // a shrinking error.
268 if (added_function_reducer_result.status !=
269 AddedFunctionReducer::AddedFunctionReducerResultStatus::kComplete) {
270 return {ShrinkerResultStatus::kAddedFunctionReductionFailed,
271 std::vector<uint32_t>(), protobufs::TransformationSequence()};
272 }
273 assert(current_best_transformations.transformation_size() ==
274 added_function_reducer_result.applied_transformations
275 .transformation_size() &&
276 "The number of transformations should not have changed.");
277 current_best_binary =
278 std::move(added_function_reducer_result.transformed_binary);
279 current_best_transformations =
280 std::move(added_function_reducer_result.applied_transformations);
281 // The added function reducer reports how many reduction attempts
282 // spirv-reduce took when reducing the function. We regard each of these
283 // as a shrinker attempt.
284 attempt += added_function_reducer_result.num_reduction_attempts;
285 }
286
287 // Indicate whether shrinking completed or was truncated due to reaching the
288 // step limit.
289 //
290 // Either way, the output from the shrinker is the best binary we saw, and the
291 // transformations that led to it.
292 assert(attempt <= step_limit_);
293 if (attempt == step_limit_) {
294 std::stringstream strstream;
295 strstream << "Shrinking did not complete; step limit " << step_limit_
296 << " was reached.";
297 consumer_(SPV_MSG_WARNING, nullptr, {}, strstream.str().c_str());
298 return {Shrinker::ShrinkerResultStatus::kStepLimitReached,
299 std::move(current_best_binary),
300 std::move(current_best_transformations)};
301 }
302 return {Shrinker::ShrinkerResultStatus::kComplete,
303 std::move(current_best_binary),
304 std::move(current_best_transformations)};
305 }
306
GetIdBound(const std::vector<uint32_t> & binary) const307 uint32_t Shrinker::GetIdBound(const std::vector<uint32_t>& binary) const {
308 // Build the module from the input binary.
309 std::unique_ptr<opt::IRContext> ir_context =
310 BuildModule(target_env_, consumer_, binary.data(), binary.size());
311 assert(ir_context && "Error building module.");
312 return ir_context->module()->id_bound();
313 }
314
315 } // namespace fuzz
316 } // namespace spvtools
317