• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_FRAMEWORK_ML_NN_RUNTIME_TEST_FUZZING_RANDOM_VARIABLE_H
18 #define ANDROID_FRAMEWORK_ML_NN_RUNTIME_TEST_FUZZING_RANDOM_VARIABLE_H
19 
#include <algorithm>
#include <climits>
#include <cstddef>
#include <iostream>
#include <map>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
29 
30 namespace android {
31 namespace nn {
32 namespace fuzzing_test {
33 
// Numeric bound used by the fuzzer for random variable values.
// NOTE(review): its exact role is defined in the implementation file; confirm there.
constexpr int kMaxValue = 10000;

// Sentinel meaning "no bound" when passed as lower/upper to the setRange() methods.
// Requires <climits> for INT_MIN.
constexpr int kInvalidValue = INT_MIN;
36 
// Describes the search range for the value of a random variable as a finite set of candidate
// values. Invariant: mChoices is sorted in ascending order and contains no duplicates. has()
// (binary search), min() and max() all depend on it, so every constructor establishes it.
class RandomVariableRange {
   public:
    // An empty range (no valid choices).
    RandomVariableRange() = default;

    // A range whose only choice is `value`.
    explicit RandomVariableRange(int value) : mChoices({value}) {}

    // A range containing every integer in the closed interval [lower, upper]. An inverted
    // interval (upper < lower) yields an empty range instead of a wrapped-around allocation size.
    RandomVariableRange(int lower, int upper) {
        if (upper < lower) return;
        // Count and iterate in long long: `upper - lower + 1`, and the increment past `upper`,
        // can both overflow int near the limits of the type.
        const long long count = static_cast<long long>(upper) - lower + 1;
        mChoices.reserve(static_cast<size_t>(count));
        for (long long v = lower; v <= upper; ++v) {
            mChoices.push_back(static_cast<int>(v));
        }
    }

    // A range with the given choices. The input may be unordered or contain duplicates; it is
    // normalized here so that the class invariant holds.
    explicit RandomVariableRange(const std::vector<int>& vec) : mChoices(vec) {
        std::sort(mChoices.begin(), mChoices.end());
        mChoices.erase(std::unique(mChoices.begin(), mChoices.end()), mChoices.end());
    }

    // A range with the elements of `st`; a std::set is already ordered and duplicate-free.
    explicit RandomVariableRange(const std::set<int>& st) : mChoices(st.begin(), st.end()) {}

    RandomVariableRange(const RandomVariableRange&) = default;
    RandomVariableRange& operator=(const RandomVariableRange&) = default;

    bool empty() const { return mChoices.empty(); }
    // Whether `value` is one of the choices (binary search over the sorted choices).
    bool has(int value) const {
        return std::binary_search(mChoices.begin(), mChoices.end(), value);
    }
    size_t size() const { return mChoices.size(); }
    // Smallest choice. Precondition: !empty().
    int min() const { return mChoices.front(); }
    // Largest choice. Precondition: !empty().
    int max() const { return mChoices.back(); }
    const std::vector<int>& getChoices() const { return mChoices; }

    // Narrow down the range to fit [lower, upper]. Use kInvalidValue to indicate unlimited bound.
    void setRange(int lower, int upper);
    // Narrow down the range to a randomly selected choice. Return the chosen value.
    int toConst();

    // Calculate the intersection of two ranges.
    friend RandomVariableRange operator&(const RandomVariableRange& lhs,
                                         const RandomVariableRange& rhs);

   private:
    // Candidate values; always sorted ascending with no duplicates.
    std::vector<int> mChoices;
};
72 
73 // Defines the interface for an operation applying to RandomVariables.
74 class IRandomVariableOp {
75    public:
~IRandomVariableOp()76     virtual ~IRandomVariableOp() {}
77     // Forward evaluation of two values.
78     virtual int eval(int lhs, int rhs) const = 0;
79     // Gets the range of the operation outcomes. The returned range must include all possible
80     // outcomes of this operation, but may contain invalid results.
81     virtual RandomVariableRange getInitRange(const RandomVariableRange& lhs,
82                                              const RandomVariableRange& rhs) const;
83     // Provides faster range evaluation for evalSubnetSingleOpHelper if possible.
84     virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
85                       const std::set<int>* childIn, std::set<int>* parent1Out,
86                       std::set<int>* parent2Out, std::set<int>* childOut) const;
87     // For debugging purpose.
88     virtual const char* getName() const = 0;
89 };
90 
// Kind of a random variable node.
enum class RandomVariableType {
    FREE = 0,   // not yet fixed; may be frozen to one valid choice of its range
    CONST = 1,  // holds a fixed value
    OP = 2      // value calculated from its parents by an operation
};
92 
93 struct RandomVariableBase {
94     // Each RandomVariableBase is assigned an unique index for debugging purpose.
95     static unsigned int globalIndex;
96     int index;
97 
98     RandomVariableType type;
99     RandomVariableRange range;
100     int value = 0;
101     std::shared_ptr<const IRandomVariableOp> op = nullptr;
102 
103     // Network structural information.
104     std::shared_ptr<RandomVariableBase> parent1 = nullptr;
105     std::shared_ptr<RandomVariableBase> parent2 = nullptr;
106     std::vector<std::shared_ptr<RandomVariableBase>> children;
107 
108     // The last time that this RandomVariableBase is modified.
109     int timestamp;
110 
111     explicit RandomVariableBase(int value);
112     RandomVariableBase(int lower, int upper);
113     explicit RandomVariableBase(const std::vector<int>& choices);
114     RandomVariableBase(const std::shared_ptr<RandomVariableBase>& lhs,
115                        const std::shared_ptr<RandomVariableBase>& rhs,
116                        const std::shared_ptr<const IRandomVariableOp>& op);
117     RandomVariableBase(const RandomVariableBase&) = delete;
118     RandomVariableBase& operator=(const RandomVariableBase&) = delete;
119 
120     // Freeze FREE RandomVariable to one valid choice.
121     // Should only invoke on FREE RandomVariable.
122     void freeze();
123 
124     // Get CONST value or calculate from parents.
125     // Should not invoke on FREE RandomVariable.
126     int getValue() const;
127 
128     // Update the timestamp to the latest global time.
129     void updateTimestamp();
130 };
131 
132 using RandomVariableNode = std::shared_ptr<RandomVariableBase>;
133 
134 // A wrapper class of RandomVariableBase that manages RandomVariableBase with shared_ptr and
135 // provides useful methods and operator overloading to build the random variable network.
136 class RandomVariable {
137    public:
138     // Construct a dummy RandomVariable with nullptr.
RandomVariable()139     RandomVariable() : mVar(nullptr) {}
140 
141     // Construct a CONST RandomVariable with specified value.
142     /* implicit */ RandomVariable(int value);
143 
144     // Construct a FREE RandomVariable with range [lower, upper].
145     RandomVariable(int lower, int upper);
146 
147     // Construct a FREE RandomVariable with specified value choices.
148     explicit RandomVariable(const std::vector<int>& choices);
149 
150     // This is for RandomVariableType::FREE only.
151     // Construct a FREE RandomVariable with default range [1, defaultValue].
152     /* implicit */ RandomVariable(RandomVariableType type);
153 
154     // RandomVariables share the same RandomVariableBase if copied or copy-assigned.
155     RandomVariable(const RandomVariable& other) = default;
156     RandomVariable& operator=(const RandomVariable& other) = default;
157 
158     // Get the value of the RandomVariable, the value must be deterministic.
getValue()159     int getValue() const { return mVar->getValue(); }
160 
161     // Get the underlying managed RandomVariableNode.
get()162     RandomVariableNode get() const { return mVar; };
163 
164     bool operator==(nullptr_t) const { return mVar == nullptr; }
165     bool operator!=(nullptr_t) const { return mVar != nullptr; }
166 
167     // Arithmetic operators and methods on RandomVariables.
168     friend RandomVariable operator+(const RandomVariable& lhs, const RandomVariable& rhs);
169     friend RandomVariable operator-(const RandomVariable& lhs, const RandomVariable& rhs);
170     friend RandomVariable operator*(const RandomVariable& lhs, const RandomVariable& rhs);
171     friend RandomVariable operator*(const RandomVariable& lhs, const float& rhs);
172     friend RandomVariable operator/(const RandomVariable& lhs, const RandomVariable& rhs);
173     friend RandomVariable operator%(const RandomVariable& lhs, const RandomVariable& rhs);
174     friend RandomVariable max(const RandomVariable& lhs, const RandomVariable& rhs);
175     friend RandomVariable min(const RandomVariable& lhs, const RandomVariable& rhs);
176     RandomVariable exactDiv(const RandomVariable& other);
177 
178     // Set constraints on the RandomVariable. Use kInvalidValue to indicate unlimited bound.
179     void setRange(int lower, int upper);
180     RandomVariable setEqual(const RandomVariable& other) const;
181     RandomVariable setGreaterThan(const RandomVariable& other) const;
182     RandomVariable setGreaterEqual(const RandomVariable& other) const;
183 
184     // A FREE RandomVariable is constructed with default range [1, defaultValue].
185     static int defaultValue;
186 
187    private:
188     // Construct a RandomVariable as the result of an OP between two other RandomVariables.
189     RandomVariable(const RandomVariable& lhs, const RandomVariable& rhs,
190                    const std::shared_ptr<const IRandomVariableOp>& op);
191     RandomVariableNode mVar;
192 };
193 
194 using EvaluationOrder = std::vector<RandomVariableNode>;
195 
196 // The base class of a network consisting of disjoint subnets.
197 class DisjointNetwork {
198    public:
199     // Add a node to the network, join the parent subnets if needed.
200     void add(const RandomVariableNode& var);
201 
202     // Similar to join(int, int), but accept RandomVariableNodes.
join(const RandomVariableNode & var1,const RandomVariableNode & var2)203     int join(const RandomVariableNode& var1, const RandomVariableNode& var2) {
204         return DisjointNetwork::join(mIndexMap[var1], mIndexMap[var2]);
205     }
206 
207    protected:
208     DisjointNetwork() = default;
209     DisjointNetwork(const DisjointNetwork&) = default;
210     DisjointNetwork& operator=(const DisjointNetwork&) = default;
211 
212     // Join two subnets by appending every node in ind2 after ind1, return the resulting subnet
213     // index. Use -1 for invalid subnet index.
214     int join(int ind1, int ind2);
215 
216     // A map from the network node to the corresponding subnet index.
217     std::unordered_map<RandomVariableNode, int> mIndexMap;
218 
219     // A map from the subnet index to the set of nodes within the subnet. The nodes are maintained
220     // in a valid evaluation order, that is, a valid topological sort.
221     std::unordered_map<int, EvaluationOrder> mEvalOrderMap;
222 
223     // The next index for a new disjoint subnet component.
224     int mNextIndex = 0;
225 };
226 
227 // Manages the active RandomVariable network. Only one instance of this class will exist.
228 class RandomVariableNetwork : public DisjointNetwork {
229    public:
230     // Returns the singleton network instance.
231     static RandomVariableNetwork* get();
232 
233     // Re-initialization. Should be called every time a new random graph is being generated.
234     void initialize(int defaultValue);
235 
236     // Set the elementwise equality of the two vectors of RandomVariables iff it results in a
237     // soluble network.
238     bool setEqualIfCompatible(const std::vector<RandomVariable>& lhs,
239                               const std::vector<RandomVariable>& rhs);
240 
241     // Freeze all FREE RandomVariables in the network to a random valid combination.
242     bool freeze();
243 
244     // Check if node2 is FREE and can be evaluated after node1.
245     bool isSubordinate(const RandomVariableNode& node1, const RandomVariableNode& node2);
246 
247     // Get and then advance the current global timestamp.
getGlobalTime()248     int getGlobalTime() { return mGlobalTime++; }
249 
250     // Add a special constraint on dimension product.
251     void addDimensionProd(const std::vector<RandomVariable>& dims);
252 
253    private:
254     RandomVariableNetwork() = default;
255     RandomVariableNetwork(const RandomVariableNetwork&) = default;
256     RandomVariableNetwork& operator=(const RandomVariableNetwork&) = default;
257 
258     // A class to revert all the changes made to RandomVariableNetwork since the Reverter object is
259     // constructed. Only used when setEqualIfCompatible results in incompatible.
260     class Reverter;
261 
262     // Find valid choices for all RandomVariables in the network. Update the RandomVariableRange
263     // if the network is soluble, otherwise, return false and leave the ranges unchanged.
264     bool evalRange();
265 
266     int mGlobalTime = 0;
267     int mTimestamp = -1;
268 
269     std::vector<EvaluationOrder> mDimProd;
270 };
271 
272 }  // namespace fuzzing_test
273 }  // namespace nn
274 }  // namespace android
275 
276 #endif  // ANDROID_FRAMEWORK_ML_NN_RUNTIME_TEST_FUZZING_RANDOM_VARIABLE_H
277