• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef SOURCE_FUZZ_EQUIVALENCE_RELATION_H_
16 #define SOURCE_FUZZ_EQUIVALENCE_RELATION_H_
17 
18 #include <algorithm>
19 #include <cassert>
20 #include <memory>
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <vector>
24 
25 #include "source/util/make_unique.h"
26 
27 namespace spvtools {
28 namespace fuzz {
29 
30 // A class for representing an equivalence relation on objects of type |T|,
31 // which should be a value type.  The type |T| is required to have a copy
32 // constructor, and |PointerHashT| and |PointerEqualsT| must be functors
33 // providing hashing and equality testing functionality for pointers to objects
34 // of type |T|.
35 //
36 // A disjoint-set (a.k.a. union-find or merge-find) data structure is used to
37 // represent the equivalence relation.  Path compression is used.  Union by
38 // rank/size is not used.
39 //
40 // Each disjoint set is represented as a tree, rooted at the representative
41 // of the set.
42 //
43 // Getting the representative of a value simply requires chasing parent pointers
44 // from the value until you reach the root.
45 //
46 // Checking equivalence of two elements requires checking that the
47 // representatives are equal.
48 //
49 // Traversing the tree rooted at a value's representative visits the value's
50 // equivalence class.
51 //
52 // |PointerHashT| and |PointerEqualsT| are used to define *equality* between
53 // values, and otherwise are *not* used to define the equivalence relation
54 // (except that equal values are equivalent).  The equivalence relation is
55 // constructed by repeatedly adding pairs of (typically non-equal) values that
56 // are deemed to be equivalent.
57 //
58 // For example in an equivalence relation on integers, 1 and 5 might be added
59 // as equivalent, so that IsEquivalent(1, 5) holds, because they represent
60 // IDs in a SPIR-V binary that are known to contain the same value at run time,
61 // but clearly 1 != 5.  Since 1 and 1 are equal, IsEquivalent(1, 1) will also
62 // hold.
63 //
64 // Each unique (up to equality) value added to the relation is copied into
65 // |owned_values_|, so there is one canonical memory address per unique value.
66 // Uniqueness is ensured by storing (and checking) a set of pointers to these
67 // values in |value_set_|, which uses |PointerHashT| and |PointerEqualsT|.
68 //
69 // |parent_| and |children_| encode the equivalence relation, i.e., the trees.
70 template <typename T, typename PointerHashT, typename PointerEqualsT>
71 class EquivalenceRelation {
72  public:
73   // Requires that |value1| and |value2| are already registered in the
74   // equivalence relation.  Merges the equivalence classes associated with
75   // |value1| and |value2|.
MakeEquivalent(const T & value1,const T & value2)76   void MakeEquivalent(const T& value1, const T& value2) {
77     assert(Exists(value1) &&
78            "Precondition: value1 must already be registered.");
79     assert(Exists(value2) &&
80            "Precondition: value2 must already be registered.");
81 
82     // Look up canonical pointers to each of the values in the value pool.
83     const T* value1_ptr = *value_set_.find(&value1);
84     const T* value2_ptr = *value_set_.find(&value2);
85 
86     // If the values turn out to be identical, they are already in the same
87     // equivalence class so there is nothing to do.
88     if (value1_ptr == value2_ptr) {
89       return;
90     }
91 
92     // Find the representative for each value's equivalence class, and if they
93     // are not already in the same class, make one the parent of the other.
94     const T* representative1 = Find(value1_ptr);
95     const T* representative2 = Find(value2_ptr);
96     assert(representative1 && "Representatives should never be null.");
97     assert(representative2 && "Representatives should never be null.");
98     if (representative1 != representative2) {
99       parent_[representative1] = representative2;
100       children_[representative2].push_back(representative1);
101     }
102   }
103 
104   // Requires that |value| is not known to the equivalence relation. Registers
105   // it in its own equivalence class and returns a pointer to the equivalence
106   // class representative.
Register(const T & value)107   const T* Register(const T& value) {
108     assert(!Exists(value));
109 
110     // This relies on T having a copy constructor.
111     auto unique_pointer_to_value = MakeUnique<T>(value);
112     auto pointer_to_value = unique_pointer_to_value.get();
113     owned_values_.push_back(std::move(unique_pointer_to_value));
114     value_set_.insert(pointer_to_value);
115 
116     // Initially say that the value is its own parent and that it has no
117     // children.
118     assert(pointer_to_value && "Representatives should never be null.");
119     parent_[pointer_to_value] = pointer_to_value;
120     children_[pointer_to_value] = std::vector<const T*>();
121 
122     return pointer_to_value;
123   }
124 
125   // Returns exactly one representative per equivalence class.
GetEquivalenceClassRepresentatives()126   std::vector<const T*> GetEquivalenceClassRepresentatives() const {
127     std::vector<const T*> result;
128     for (auto& value : owned_values_) {
129       if (parent_[value.get()] == value.get()) {
130         result.push_back(value.get());
131       }
132     }
133     return result;
134   }
135 
136   // Returns pointers to all values in the equivalence class of |value|, which
137   // must already be part of the equivalence relation.
GetEquivalenceClass(const T & value)138   std::vector<const T*> GetEquivalenceClass(const T& value) const {
139     assert(Exists(value));
140 
141     std::vector<const T*> result;
142 
143     // Traverse the tree of values rooted at the representative of the
144     // equivalence class to which |value| belongs, and collect up all the values
145     // that are encountered.  This constitutes the whole equivalence class.
146     std::vector<const T*> stack;
147     stack.push_back(Find(*value_set_.find(&value)));
148     while (!stack.empty()) {
149       const T* item = stack.back();
150       result.push_back(item);
151       stack.pop_back();
152       for (auto child : children_[item]) {
153         stack.push_back(child);
154       }
155     }
156     return result;
157   }
158 
159   // Returns true if and only if |value1| and |value2| are in the same
160   // equivalence class.  Both values must already be known to the equivalence
161   // relation.
IsEquivalent(const T & value1,const T & value2)162   bool IsEquivalent(const T& value1, const T& value2) const {
163     return Find(&value1) == Find(&value2);
164   }
165 
166   // Returns all values known to be part of the equivalence relation.
GetAllKnownValues()167   std::vector<const T*> GetAllKnownValues() const {
168     std::vector<const T*> result;
169     for (auto& value : owned_values_) {
170       result.push_back(value.get());
171     }
172     return result;
173   }
174 
175   // Returns true if and only if |value| is known to be part of the equivalence
176   // relation.
Exists(const T & value)177   bool Exists(const T& value) const {
178     return value_set_.find(&value) != value_set_.end();
179   }
180 
181   // Returns the representative of the equivalence class of |value|, which must
182   // already be known to the equivalence relation.  This is the 'Find' operation
183   // in a classic union-find data structure.
Find(const T * value)184   const T* Find(const T* value) const {
185     assert(Exists(*value));
186 
187     // Get the canonical pointer to the value from the value pool.
188     const T* known_value = *value_set_.find(value);
189     assert(parent_[known_value] && "Every known value should have a parent.");
190 
191     // Compute the result by chasing parents until we find a value that is its
192     // own parent.
193     const T* result = known_value;
194     while (parent_[result] != result) {
195       result = parent_[result];
196     }
197     assert(result && "Representatives should never be null.");
198 
199     // At this point, |result| is the representative of the equivalence class.
200     // Now perform the 'path compression' optimization by doing another pass up
201     // the parent chain, setting the parent of each node to be the
202     // representative, and rewriting children correspondingly.
203     const T* current = known_value;
204     while (parent_[current] != result) {
205       const T* next = parent_[current];
206       parent_[current] = result;
207       children_[result].push_back(current);
208       auto child_iterator =
209           std::find(children_[next].begin(), children_[next].end(), current);
210       assert(child_iterator != children_[next].end() &&
211              "'next' is the parent of 'current', so 'current' should be a "
212              "child of 'next'");
213       children_[next].erase(child_iterator);
214       current = next;
215     }
216     return result;
217   }
218 
219  private:
220   // Maps every value to a parent.  The representative of an equivalence class
221   // is its own parent.  A value's representative can be found by walking its
222   // chain of ancestors.
223   //
224   // Mutable because the intuitively const method, 'Find', performs path
225   // compression.
226   mutable std::unordered_map<const T*, const T*> parent_;
227 
228   // Stores the children of each value.  This allows the equivalence class of
229   // a value to be calculated by traversing all descendents of the class's
230   // representative.
231   //
232   // Mutable because the intuitively const method, 'Find', performs path
233   // compression.
234   mutable std::unordered_map<const T*, std::vector<const T*>> children_;
235 
236   // The values known to the equivalence relation are allocated in
237   // |owned_values_|, and |value_pool_| provides (via |PointerHashT| and
238   // |PointerEqualsT|) a means for mapping a value of interest to a pointer
239   // into an equivalent value in |owned_values_|.
240   std::unordered_set<const T*, PointerHashT, PointerEqualsT> value_set_;
241   std::vector<std::unique_ptr<T>> owned_values_;
242 };
243 
244 }  // namespace fuzz
245 }  // namespace spvtools
246 
247 #endif  // SOURCE_FUZZ_EQUIVALENCE_RELATION_H_
248