1 // Copyright (c) 2019 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef SOURCE_FUZZ_EQUIVALENCE_RELATION_H_ 16 #define SOURCE_FUZZ_EQUIVALENCE_RELATION_H_ 17 18 #include <algorithm> 19 #include <cassert> 20 #include <memory> 21 #include <unordered_map> 22 #include <unordered_set> 23 #include <vector> 24 25 #include "source/util/make_unique.h" 26 27 namespace spvtools { 28 namespace fuzz { 29 30 // A class for representing an equivalence relation on objects of type |T|, 31 // which should be a value type. The type |T| is required to have a copy 32 // constructor, and |PointerHashT| and |PointerEqualsT| must be functors 33 // providing hashing and equality testing functionality for pointers to objects 34 // of type |T|. 35 // 36 // A disjoint-set (a.k.a. union-find or merge-find) data structure is used to 37 // represent the equivalence relation. Path compression is used. Union by 38 // rank/size is not used. 39 // 40 // Each disjoint set is represented as a tree, rooted at the representative 41 // of the set. 42 // 43 // Getting the representative of a value simply requires chasing parent pointers 44 // from the value until you reach the root. 45 // 46 // Checking equivalence of two elements requires checking that the 47 // representatives are equal. 48 // 49 // Traversing the tree rooted at a value's representative visits the value's 50 // equivalence class. 51 // 52 // |PointerHashT| and |PointerEqualsT| are used to define *equality* between 53 // values, and otherwise are *not* used to define the equivalence relation 54 // (except that equal values are equivalent). The equivalence relation is 55 // constructed by repeatedly adding pairs of (typically non-equal) values that 56 // are deemed to be equivalent. 57 // 58 // For example in an equivalence relation on integers, 1 and 5 might be added 59 // as equivalent, so that IsEquivalent(1, 5) holds, because they represent 60 // IDs in a SPIR-V binary that are known to contain the same value at run time, 61 // but clearly 1 != 5. Since 1 and 1 are equal, IsEquivalent(1, 1) will also 62 // hold. 63 // 64 // Each unique (up to equality) value added to the relation is copied into 65 // |owned_values_|, so there is one canonical memory address per unique value. 66 // Uniqueness is ensured by storing (and checking) a set of pointers to these 67 // values in |value_set_|, which uses |PointerHashT| and |PointerEqualsT|. 68 // 69 // |parent_| and |children_| encode the equivalence relation, i.e., the trees. 70 template <typename T, typename PointerHashT, typename PointerEqualsT> 71 class EquivalenceRelation { 72 public: 73 // Requires that |value1| and |value2| are already registered in the 74 // equivalence relation. Merges the equivalence classes associated with 75 // |value1| and |value2|. MakeEquivalent(const T & value1,const T & value2)76 void MakeEquivalent(const T& value1, const T& value2) { 77 assert(Exists(value1) && 78 "Precondition: value1 must already be registered."); 79 assert(Exists(value2) && 80 "Precondition: value2 must already be registered."); 81 82 // Look up canonical pointers to each of the values in the value pool. 83 const T* value1_ptr = *value_set_.find(&value1); 84 const T* value2_ptr = *value_set_.find(&value2); 85 86 // If the values turn out to be identical, they are already in the same 87 // equivalence class so there is nothing to do. 88 if (value1_ptr == value2_ptr) { 89 return; 90 } 91 92 // Find the representative for each value's equivalence class, and if they 93 // are not already in the same class, make one the parent of the other. 94 const T* representative1 = Find(value1_ptr); 95 const T* representative2 = Find(value2_ptr); 96 assert(representative1 && "Representatives should never be null."); 97 assert(representative2 && "Representatives should never be null."); 98 if (representative1 != representative2) { 99 parent_[representative1] = representative2; 100 children_[representative2].push_back(representative1); 101 } 102 } 103 104 // Requires that |value| is not known to the equivalence relation. Registers 105 // it in its own equivalence class and returns a pointer to the equivalence 106 // class representative. Register(const T & value)107 const T* Register(const T& value) { 108 assert(!Exists(value)); 109 110 // This relies on T having a copy constructor. 111 auto unique_pointer_to_value = MakeUnique<T>(value); 112 auto pointer_to_value = unique_pointer_to_value.get(); 113 owned_values_.push_back(std::move(unique_pointer_to_value)); 114 value_set_.insert(pointer_to_value); 115 116 // Initially say that the value is its own parent and that it has no 117 // children. 118 assert(pointer_to_value && "Representatives should never be null."); 119 parent_[pointer_to_value] = pointer_to_value; 120 children_[pointer_to_value] = std::vector<const T*>(); 121 122 return pointer_to_value; 123 } 124 125 // Returns exactly one representative per equivalence class. GetEquivalenceClassRepresentatives()126 std::vector<const T*> GetEquivalenceClassRepresentatives() const { 127 std::vector<const T*> result; 128 for (auto& value : owned_values_) { 129 if (parent_[value.get()] == value.get()) { 130 result.push_back(value.get()); 131 } 132 } 133 return result; 134 } 135 136 // Returns pointers to all values in the equivalence class of |value|, which 137 // must already be part of the equivalence relation. GetEquivalenceClass(const T & value)138 std::vector<const T*> GetEquivalenceClass(const T& value) const { 139 assert(Exists(value)); 140 141 std::vector<const T*> result; 142 143 // Traverse the tree of values rooted at the representative of the 144 // equivalence class to which |value| belongs, and collect up all the values 145 // that are encountered. This constitutes the whole equivalence class. 146 std::vector<const T*> stack; 147 stack.push_back(Find(*value_set_.find(&value))); 148 while (!stack.empty()) { 149 const T* item = stack.back(); 150 result.push_back(item); 151 stack.pop_back(); 152 for (auto child : children_[item]) { 153 stack.push_back(child); 154 } 155 } 156 return result; 157 } 158 159 // Returns true if and only if |value1| and |value2| are in the same 160 // equivalence class. Both values must already be known to the equivalence 161 // relation. IsEquivalent(const T & value1,const T & value2)162 bool IsEquivalent(const T& value1, const T& value2) const { 163 return Find(&value1) == Find(&value2); 164 } 165 166 // Returns all values known to be part of the equivalence relation. GetAllKnownValues()167 std::vector<const T*> GetAllKnownValues() const { 168 std::vector<const T*> result; 169 for (auto& value : owned_values_) { 170 result.push_back(value.get()); 171 } 172 return result; 173 } 174 175 // Returns true if and only if |value| is known to be part of the equivalence 176 // relation. Exists(const T & value)177 bool Exists(const T& value) const { 178 return value_set_.find(&value) != value_set_.end(); 179 } 180 181 // Returns the representative of the equivalence class of |value|, which must 182 // already be known to the equivalence relation. This is the 'Find' operation 183 // in a classic union-find data structure. Find(const T * value)184 const T* Find(const T* value) const { 185 assert(Exists(*value)); 186 187 // Get the canonical pointer to the value from the value pool. 188 const T* known_value = *value_set_.find(value); 189 assert(parent_[known_value] && "Every known value should have a parent."); 190 191 // Compute the result by chasing parents until we find a value that is its 192 // own parent. 193 const T* result = known_value; 194 while (parent_[result] != result) { 195 result = parent_[result]; 196 } 197 assert(result && "Representatives should never be null."); 198 199 // At this point, |result| is the representative of the equivalence class. 200 // Now perform the 'path compression' optimization by doing another pass up 201 // the parent chain, setting the parent of each node to be the 202 // representative, and rewriting children correspondingly. 203 const T* current = known_value; 204 while (parent_[current] != result) { 205 const T* next = parent_[current]; 206 parent_[current] = result; 207 children_[result].push_back(current); 208 auto child_iterator = 209 std::find(children_[next].begin(), children_[next].end(), current); 210 assert(child_iterator != children_[next].end() && 211 "'next' is the parent of 'current', so 'current' should be a " 212 "child of 'next'"); 213 children_[next].erase(child_iterator); 214 current = next; 215 } 216 return result; 217 } 218 219 private: 220 // Maps every value to a parent. The representative of an equivalence class 221 // is its own parent. A value's representative can be found by walking its 222 // chain of ancestors. 223 // 224 // Mutable because the intuitively const method, 'Find', performs path 225 // compression. 226 mutable std::unordered_map<const T*, const T*> parent_; 227 228 // Stores the children of each value. This allows the equivalence class of 229 // a value to be calculated by traversing all descendents of the class's 230 // representative. 231 // 232 // Mutable because the intuitively const method, 'Find', performs path 233 // compression. 234 mutable std::unordered_map<const T*, std::vector<const T*>> children_; 235 236 // The values known to the equivalence relation are allocated in 237 // |owned_values_|, and |value_pool_| provides (via |PointerHashT| and 238 // |PointerEqualsT|) a means for mapping a value of interest to a pointer 239 // into an equivalent value in |owned_values_|. 240 std::unordered_set<const T*, PointerHashT, PointerEqualsT> value_set_; 241 std::vector<std::unique_ptr<T>> owned_values_; 242 }; 243 244 } // namespace fuzz 245 } // namespace spvtools 246 247 #endif // SOURCE_FUZZ_EQUIVALENCE_RELATION_H_ 248