1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef ANDROID_FRAMEWORK_ML_NN_RUNTIME_TEST_FUZZING_RANDOM_GRAPH_GENERATOR_UTILS_H
18 #define ANDROID_FRAMEWORK_ML_NN_RUNTIME_TEST_FUZZING_RANDOM_GRAPH_GENERATOR_UTILS_H
19
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <limits>
#include <memory>
#include <random>
#include <sstream>
#include <string>
#include <type_traits>
#include <vector>

#include "RandomGraphGenerator.h"
#include "RandomVariable.h"
#include "TestNeuralNetworksWrapper.h"
30
31 namespace android {
32 namespace nn {
33 namespace fuzzing_test {
34
// Opens `filename` on the global Logger; logging is enabled while that file is open.
#define NN_FUZZER_LOG_INIT(filename) Logger::get()->init((filename))
// Closes the global Logger's file, after which NN_FUZZER_LOG does nothing.
#define NN_FUZZER_LOG_CLOSE Logger::get()->close()
// Streams one log line, prefixed with the calling function's name padded by alignedString.
// Usage: NN_FUZZER_LOG << "text" << value;
// The streamed expression sits in the else branch, so it is not evaluated while the Logger
// is disabled.
#define NN_FUZZER_LOG              \
    if (!Logger::get()->enabled()) \
        ;                          \
    else                           \
        LoggerStream(false) << alignedString(__FUNCTION__, 20)
// Fatal assertion. When `condition` is false, a line starting with the calling function's
// name and "Check failed <condition>: " is built, anything streamed after the macro is
// appended, and the LoggerStream destructor then logs the line, echoes it to stdout and
// aborts. Usage: NN_FUZZER_CHECK(x > 0) << "x must be positive";
#define NN_FUZZER_CHECK(condition)                                                             \
    if ((condition))                                                                           \
        ;                                                                                      \
    else                                                                                       \
        LoggerStream(true) << alignedString(__FUNCTION__, 20) << "Check failed " << #condition \
                           << ": "
48
// A singleton that manages the global logging configuration: the output file and the time
// origin that elapsed-time prefixes are measured from.
class Logger {
   public:
    // Returns the single process-wide instance, constructed on first use.
    static Logger* get() {
        static Logger instance;
        return &instance;
    }
    // Starts logging to `filename` and restarts the elapsed-time clock.
    // A file left open by an earlier init() is closed first: std::ofstream::open() on a stream
    // that is already open fails and sets failbit, which would silently drop every later
    // log() call.
    void init(const std::string& filename) {
        close();
        os.open(filename);
        mStart = std::chrono::high_resolution_clock::now();
    }
    // True while a log file is open.
    bool enabled() const { return os.is_open(); }
    // Closes the log file if one is open; safe to call repeatedly.
    void close() {
        if (os.is_open()) os.close();
    }
    // Writes the elapsed-time prefix followed by `str` and flushes, so that the text reaches
    // the file even if the process aborts right afterwards. Does nothing while disabled.
    void log(const std::string& str) {
        if (os.is_open()) os << getElapsedTime() << str << std::flush;
    }

   private:
    Logger() = default;
    Logger(const Logger&) = delete;
    Logger& operator=(const Logger&) = delete;
    // Defined out of line; produces the prefix written before every logged string.
    std::string getElapsedTime();
    std::ofstream os;
    std::chrono::time_point<std::chrono::high_resolution_clock> mStart;
};
76
77 // Controls logging of a single line.
78 class LoggerStream {
79 public:
LoggerStream(bool abortAfterLog)80 LoggerStream(bool abortAfterLog) : mAbortAfterLog(abortAfterLog) {}
~LoggerStream()81 ~LoggerStream() {
82 Logger::get()->log(ss.str() + '\n');
83 if (mAbortAfterLog) {
84 std::cout << ss.str() << std::endl;
85 abort();
86 }
87 }
88
89 template <typename T>
90 LoggerStream& operator<<(const T& str) {
91 ss << str;
92 return *this;
93 }
94
95 private:
96 LoggerStream(const LoggerStream&) = delete;
97 LoggerStream& operator=(const LoggerStream&) = delete;
98 std::stringstream ss;
99 bool mAbortAfterLog;
100 };
101
// Generic conversion to text: forwards to std::to_string, so it accepts exactly the arithmetic
// types std::to_string is overloaded for. Other types need an explicit specialization.
template <typename T>
inline std::string toString(const T& obj) {
    std::string text = std::to_string(obj);
    return text;
}
106
// Concatenates toString(item) for every element of `items`, placing `joint` between
// consecutive elements. Returns an empty string for an empty vector.
template <typename T>
inline std::string joinStr(const std::string& joint, const std::vector<T>& items) {
    std::stringstream out;
    bool needsSeparator = false;
    for (const T& item : items) {
        if (needsSeparator) out << joint;
        out << toString(item);
        needsSeparator = true;
    }
    return out.str();
}
119
// Same as the two-argument joinStr, except that each element is rendered by the caller-supplied
// callable `str`, whose result is streamed into the output.
template <typename T, class Function>
inline std::string joinStr(const std::string& joint, const std::vector<T>& items, Function str) {
    std::stringstream out;
    auto it = items.begin();
    if (it != items.end()) {
        // The first element carries no leading separator.
        out << str(*it);
        for (++it; it != items.end(); ++it) {
            out << joint;
            out << str(*it);
        }
    }
    return out.str();
}
129
// Joins `items` like the two-argument joinStr, but abbreviates long lists: the first `limit`
// elements are printed, then "(N omitted)", then the last element, all separated by `joint`.
// N counts only the elements that are really hidden, i.e. it excludes the printed last element.
// Lists with nothing to hide (size <= limit + 1) are printed in full, as are all lists when
// `limit` is negative.
// Example: joint = ", ", limit = 2, items = {1, 2, 3, 4, 5}  ->  "1, 2, (2 omitted), 5".
template <typename T>
inline std::string joinStr(const std::string& joint, int limit, const std::vector<T>& items) {
    // A negative limit wraps around to a huge value here, which disables the abbreviation.
    const size_t headCount = static_cast<size_t>(limit);
    if (items.size() <= headCount || items.size() - headCount < 2) {
        return joinStr(joint, items);
    }
    const std::vector<T> head(items.begin(), items.begin() + limit);
    std::string result = joinStr(joint, head);
    // With limit == 0 there is no head, so no separator is needed before the marker.
    if (!head.empty()) result += joint;
    result += "(" + toString(items.size() - headCount - 1) + " omitted)" + joint;
    result += toString(items.back());
    return result;
}
140
// Printable operation names. The position of a name in this table is significant; do not
// reorder or insert entries in the middle.
// TODO: Currently only 1.0 operations and operand types.
// NOTE(review): the TODO above looks stale -- presumably the table is indexed by the NNAPI
// operation code, and it already lists far more than an initial operation set; confirm and
// update the remark.
// NOTE(review): "L2_POOL" lacks the "_2D" suffix that "AVERAGE_POOL_2D" and "MAX_POOL_2D"
// carry; confirm that this is the spelling the consumers of this table expect.
static const char* kOperationNames[] = {
        "ADD",
        "AVERAGE_POOL_2D",
        "CONCATENATION",
        "CONV_2D",
        "DEPTHWISE_CONV_2D",
        "DEPTH_TO_SPACE",
        "DEQUANTIZE",
        "EMBEDDING_LOOKUP",
        "FLOOR",
        "FULLY_CONNECTED",
        "HASHTABLE_LOOKUP",
        "L2_NORMALIZATION",
        "L2_POOL",
        "LOCAL_RESPONSE_NORMALIZATION",
        "LOGISTIC",
        "LSH_PROJECTION",
        "LSTM",
        "MAX_POOL_2D",
        "MUL",
        "RELU",
        "RELU1",
        "RELU6",
        "RESHAPE",
        "RESIZE_BILINEAR",
        "RNN",
        "SOFTMAX",
        "SPACE_TO_DEPTH",
        "SVDF",
        "TANH",
        "BATCH_TO_SPACE_ND",
        "DIV",
        "MEAN",
        "PAD",
        "SPACE_TO_BATCH_ND",
        "SQUEEZE",
        "STRIDED_SLICE",
        "SUB",
        "TRANSPOSE",
        "ABS",
        "ARGMAX",
        "ARGMIN",
        "AXIS_ALIGNED_BBOX_TRANSFORM",
        "BIDIRECTIONAL_SEQUENCE_LSTM",
        "BIDIRECTIONAL_SEQUENCE_RNN",
        "BOX_WITH_NMS_LIMIT",
        "CAST",
        "CHANNEL_SHUFFLE",
        "DETECTION_POSTPROCESSING",
        "EQUAL",
        "EXP",
        "EXPAND_DIMS",
        "GATHER",
        "GENERATE_PROPOSALS",
        "GREATER",
        "GREATER_EQUAL",
        "GROUPED_CONV_2D",
        "HEATMAP_MAX_KEYPOINT",
        "INSTANCE_NORMALIZATION",
        "LESS",
        "LESS_EQUAL",
        "LOG",
        "LOGICAL_AND",
        "LOGICAL_NOT",
        "LOGICAL_OR",
        "LOG_SOFTMAX",
        "MAXIMUM",
        "MINIMUM",
        "NEG",
        "NOT_EQUAL",
        "PAD_V2",
        "POW",
        "PRELU",
        "QUANTIZE",
        "QUANTIZED_16BIT_LSTM",
        "RANDOM_MULTINOMIAL",
        "REDUCE_ALL",
        "REDUCE_ANY",
        "REDUCE_MAX",
        "REDUCE_MIN",
        "REDUCE_PROD",
        "REDUCE_SUM",
        "ROI_ALIGN",
        "ROI_POOLING",
        "RSQRT",
        "SELECT",
        "SIN",
        "SLICE",
        "SPLIT",
        "SQRT",
        "TILE",
        "TOPK_V2",
        "TRANSPOSE_CONV_2D",
        "UNIDIRECTIONAL_SEQUENCE_LSTM",
        "UNIDIRECTIONAL_SEQUENCE_RNN",
        "RESIZE_NEAREST_NEIGHBOR",
};
239
// Printable operand type names. toString<Type> indexes this table with the numeric value of
// the type, so the order of the entries must be preserved.
static const char* kTypeNames[] = {
        // 0 - 3
        "FLOAT32", "INT32", "UINT32", "TENSOR_FLOAT32",
        // 4 - 7
        "TENSOR_INT32", "TENSOR_QUANT8_ASYMM", "BOOL", "TENSOR_QUANT16_SYMM",
        // 8 - 11
        "TENSOR_FLOAT16", "TENSOR_BOOL8", "FLOAT16", "TENSOR_QUANT8_SYMM_PER_CHANNEL",
        // 12 - 13
        "TENSOR_QUANT16_ASYMM", "TENSOR_QUANT8_SYMM",
};
256
// Printable operand lifetime names, listed in table order (the position is the lookup index).
static const char* kLifeTimeNames[6] = {
        "TEMPORARY_VARIABLE",  // 0
        "MODEL_INPUT",         // 1
        "MODEL_OUTPUT",        // 2
        "CONSTANT_COPY",       // 3
        "CONSTANT_REFERENCE",  // 4
        "NO_VALUE",            // 5
};
261
// Whether each operand type is a scalar (true) or a tensor (false). The entries follow the
// same order as kTypeNames.
static const bool kScalarDataType[] = {
        true,   //  0: ANEURALNETWORKS_FLOAT32
        true,   //  1: ANEURALNETWORKS_INT32
        true,   //  2: ANEURALNETWORKS_UINT32
        false,  //  3: ANEURALNETWORKS_TENSOR_FLOAT32
        false,  //  4: ANEURALNETWORKS_TENSOR_INT32
        false,  //  5: ANEURALNETWORKS_TENSOR_QUANT8_ASYMM
        true,   //  6: ANEURALNETWORKS_BOOL
        false,  //  7: ANEURALNETWORKS_TENSOR_QUANT16_SYMM
        false,  //  8: ANEURALNETWORKS_TENSOR_FLOAT16
        false,  //  9: ANEURALNETWORKS_TENSOR_BOOL8
        true,   // 10: ANEURALNETWORKS_FLOAT16
        false,  // 11: ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL
        false,  // 12: ANEURALNETWORKS_TENSOR_QUANT16_ASYMM
        false,  // 13: ANEURALNETWORKS_TENSOR_QUANT8_SYMM
};
278
// Size in bytes of one element of each operand type. The entries follow the same order as
// kTypeNames.
static const uint32_t kSizeOfDataType[] = {
        4,  //  0: ANEURALNETWORKS_FLOAT32
        4,  //  1: ANEURALNETWORKS_INT32
        4,  //  2: ANEURALNETWORKS_UINT32
        4,  //  3: ANEURALNETWORKS_TENSOR_FLOAT32
        4,  //  4: ANEURALNETWORKS_TENSOR_INT32
        1,  //  5: ANEURALNETWORKS_TENSOR_QUANT8_ASYMM
        1,  //  6: ANEURALNETWORKS_BOOL
        2,  //  7: ANEURALNETWORKS_TENSOR_QUANT16_SYMM
        2,  //  8: ANEURALNETWORKS_TENSOR_FLOAT16
        1,  //  9: ANEURALNETWORKS_TENSOR_BOOL8
        2,  // 10: ANEURALNETWORKS_FLOAT16
        1,  // 11: ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL
        2,  // 12: ANEURALNETWORKS_TENSOR_QUANT16_ASYMM
        1,  // 13: ANEURALNETWORKS_TENSOR_QUANT8_SYMM
};
295
296 template <>
297 inline std::string toString<RandomVariableType>(const RandomVariableType& type) {
298 static const std::string typeNames[] = {"FREE", "CONST", "OP"};
299 return typeNames[static_cast<int>(type)];
300 }
301
// Appends ':' to `str` and then forces the result to exactly `width + 1` characters: shorter
// text is padded with spaces on the right, longer text is cut off (which can remove the colon).
inline std::string alignedString(std::string str, int width) {
    str += ':';
    const std::string::size_type target = width + 1;
    if (str.size() < target) {
        str.append(target - str.size(), ' ');
    } else {
        str.erase(target);
    }
    return str;
}
307
308 template <>
309 inline std::string toString<RandomVariableRange>(const RandomVariableRange& range) {
310 return "[" + joinStr(", ", 20, range.getChoices()) + "]";
311 }
312
313 template <>
314 inline std::string toString<RandomOperandType>(const RandomOperandType& type) {
315 static const std::string typeNames[] = {"Input", "Output", "Internal", "Parameter"};
316 return typeNames[static_cast<int>(type)];
317 }
318
319 template <>
320 inline std::string toString<RandomVariableNode>(const RandomVariableNode& var) {
321 std::stringstream ss;
322 ss << "var" << var->index << " = ";
323 switch (var->type) {
324 case RandomVariableType::FREE:
325 ss << "FREE " << toString(var->range);
326 break;
327 case RandomVariableType::CONST:
328 ss << "CONST " << toString(var->value);
329 break;
330 case RandomVariableType::OP:
331 ss << "var" << var->parent1->index << " " << var->op->getName();
332 if (var->parent2 != nullptr) ss << " var" << var->parent2->index;
333 ss << ", " << toString(var->range);
334 break;
335 default:
336 NN_FUZZER_CHECK(false);
337 }
338 ss << ", timestamp = " << var->timestamp;
339 return ss.str();
340 }
341
342 template <>
343 inline std::string toString<Type>(const Type& type) {
344 return kTypeNames[static_cast<int32_t>(type)];
345 }
346
347 template <>
348 inline std::string toString<RandomVariable>(const RandomVariable& var) {
349 return "var" + std::to_string(var.get()->index);
350 }
351
352 template <>
353 inline std::string toString<RandomOperand>(const RandomOperand& op) {
354 return toString(op.type) + ", dimension = [" +
355 joinStr(", ", op.dimensions,
356 [](const RandomVariable& var) { return std::to_string(var.getValue()); }) +
357 "], scale = " + toString(op.scale) + " , zero_point = " + toString(op.zeroPoint);
358 }
359
// A boolean that is guaranteed to occupy exactly one byte. It works around two issues:
// 1. sizeof(bool) is implementation defined.
// 2. vector<bool> does not allow direct pointer access via the data() method.
class bool8 {
   public:
    // Default-constructed values are false.
    bool8() : mValue(0) {}
    // Deliberately implicit, so that plain bool values convert transparently.
    /* implicit */ bool8(bool value) : mValue(value ? 1 : 0) {}
    // Any non-zero stored byte reads back as true.
    operator bool() const { return mValue != 0; }

   private:
    uint8_t mValue;
};
static_assert(sizeof(bool8) == 1, "size of bool8 must be 8 bits");
373
374 // Dump the random graph to a spec file.
375 class SpecWriter {
376 public:
377 SpecWriter(std::string filename, std::string testname = "");
isOpen()378 bool isOpen() { return os.is_open(); }
379 void dump(const std::vector<RandomOperation>& operations,
380 const std::vector<std::shared_ptr<RandomOperand>>& operands);
381
382 private:
383 void dump(test_wrapper::Type type, const uint8_t* buffer, uint32_t length);
384 void dump(const std::vector<RandomVariable>& dimensions);
385 void dump(const std::shared_ptr<RandomOperand>& op);
386 void dump(const RandomOperation& op);
387
388 template <typename T>
dump(const T * buffer,uint32_t length)389 void dump(const T* buffer, uint32_t length) {
390 for (uint32_t i = 0; i < length; i++) {
391 if (i != 0) os << ", ";
392 if constexpr (std::is_integral<T>::value) {
393 os << static_cast<int>(buffer[i]);
394 } else if constexpr (std::is_same<T, _Float16>::value) {
395 os << static_cast<float>(buffer[i]);
396 } else if constexpr (std::is_same<T, bool8>::value) {
397 os << (buffer[i] ? "True" : "False");
398 } else {
399 os << buffer[i];
400 }
401 }
402 }
403
404 std::ofstream os;
405 };
406
// Holds the single random engine that the random helpers in this header draw from.
struct RandomNumberGenerator {
    // Declaration only; the definition is not part of this header.
    static std::mt19937 generator;
};
410
getBernoulli(double p)411 inline bool getBernoulli(double p) {
412 std::bernoulli_distribution dis(p);
413 return dis(RandomNumberGenerator::generator);
414 }
415
416 // std::is_floating_point_v<_Float16> evaluates to true in CTS build target but false in
417 // NeuralNetworksTest_static, so we define getUniform<_Float16> explicitly here if not CTS.
418 #ifdef NNTEST_CTS
419 #define NN_IS_FLOAT(T) std::is_floating_point_v<T>
420 #else
421 #define NN_IS_FLOAT(T) std::is_floating_point_v<T> || std::is_same_v<T, _Float16>
422 #endif
423
424 // getUniform for floating point values operates on a open interval (lower, upper).
425 // This is important for generating a scale that is greater than but not equal to a lower bound.
426 template <typename T>
getUniform(T lower,T upper)427 inline std::enable_if_t<NN_IS_FLOAT(T), T> getUniform(T lower, T upper) {
428 float nextLower = std::nextafter(static_cast<float>(lower), std::numeric_limits<float>::max());
429 std::uniform_real_distribution<float> dis(nextLower, upper);
430 return dis(RandomNumberGenerator::generator);
431 }
432
433 // getUniform for integers operates on a closed interval [lower, upper].
434 // This is important that 255 should be included as a valid candidate for QUANT8_ASYMM values.
435 template <typename T>
getUniform(T lower,T upper)436 inline std::enable_if_t<std::is_integral_v<T>, T> getUniform(T lower, T upper) {
437 std::uniform_int_distribution<T> dis(lower, upper);
438 return dis(RandomNumberGenerator::generator);
439 }
440
441 template <typename T>
getRandomChoice(const std::vector<T> & choices)442 inline const T& getRandomChoice(const std::vector<T>& choices) {
443 NN_FUZZER_CHECK(!choices.empty()) << "Empty choices!";
444 std::uniform_int_distribution<size_t> dis(0, choices.size() - 1);
445 size_t i = dis(RandomNumberGenerator::generator);
446 return choices[i];
447 }
448
449 template <typename T>
randomShuffle(std::vector<T> * vec)450 inline void randomShuffle(std::vector<T>* vec) {
451 std::shuffle(vec->begin(), vec->end(), RandomNumberGenerator::generator);
452 }
453
454 } // namespace fuzzing_test
455 } // namespace nn
456 } // namespace android
457
458 #endif // ANDROID_FRAMEWORK_ML_NN_RUNTIME_TEST_FUZZING_RANDOM_GRAPH_GENERATOR_UTILS_H
459