// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef SOURCE_FUZZ_EQUIVALENCE_RELATION_H_
#define SOURCE_FUZZ_EQUIVALENCE_RELATION_H_

#include <algorithm>
#include <cassert>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "source/util/make_unique.h"

namespace spvtools {
namespace fuzz {

// A class for representing an equivalence relation on objects of type |T|,
// which should be a value type. The type |T| is required to have a copy
// constructor, and |PointerHashT| and |PointerEqualsT| must be functors
// providing hashing and equality testing functionality for pointers to objects
// of type |T|.
//
// A disjoint-set (a.k.a. union-find or merge-find) data structure is used to
// represent the equivalence relation. Path compression is used. Union by
// rank/size is not used.
//
// Each disjoint set is represented as a tree, rooted at the representative
// of the set.
//
// Getting the representative of a value simply requires chasing parent pointers
// from the value until you reach the root.
//
// Checking equivalence of two elements requires checking that the
// representatives are equal.
//
// Traversing the tree rooted at a value's representative visits the value's
// equivalence class.
49 // 50 // |PointerHashT| and |PointerEqualsT| are used to define *equality* between 51 // values, and otherwise are *not* used to define the equivalence relation 52 // (except that equal values are equivalent). The equivalence relation is 53 // constructed by repeatedly adding pairs of (typically non-equal) values that 54 // are deemed to be equivalent. 55 // 56 // For example in an equivalence relation on integers, 1 and 5 might be added 57 // as equivalent, so that IsEquivalent(1, 5) holds, because they represent 58 // IDs in a SPIR-V binary that are known to contain the same value at run time, 59 // but clearly 1 != 5. Since 1 and 1 are equal, IsEquivalent(1, 1) will also 60 // hold. 61 // 62 // Each unique (up to equality) value added to the relation is copied into 63 // |owned_values_|, so there is one canonical memory address per unique value. 64 // Uniqueness is ensured by storing (and checking) a set of pointers to these 65 // values in |value_set_|, which uses |PointerHashT| and |PointerEqualsT|. 66 // 67 // |parent_| and |children_| encode the equivalence relation, i.e., the trees. 68 template <typename T, typename PointerHashT, typename PointerEqualsT> 69 class EquivalenceRelation { 70 public: 71 // Requires that |value1| and |value2| are already registered in the 72 // equivalence relation. Merges the equivalence classes associated with 73 // |value1| and |value2|. MakeEquivalent(const T & value1,const T & value2)74 void MakeEquivalent(const T& value1, const T& value2) { 75 assert(Exists(value1) && 76 "Precondition: value1 must already be registered."); 77 assert(Exists(value2) && 78 "Precondition: value2 must already be registered."); 79 80 // Look up canonical pointers to each of the values in the value pool. 81 const T* value1_ptr = *value_set_.find(&value1); 82 const T* value2_ptr = *value_set_.find(&value2); 83 84 // If the values turn out to be identical, they are already in the same 85 // equivalence class so there is nothing to do. 
86 if (value1_ptr == value2_ptr) { 87 return; 88 } 89 90 // Find the representative for each value's equivalence class, and if they 91 // are not already in the same class, make one the parent of the other. 92 const T* representative1 = Find(value1_ptr); 93 const T* representative2 = Find(value2_ptr); 94 assert(representative1 && "Representatives should never be null."); 95 assert(representative2 && "Representatives should never be null."); 96 if (representative1 != representative2) { 97 parent_[representative1] = representative2; 98 children_[representative2].push_back(representative1); 99 } 100 } 101 102 // Requires that |value| is not known to the equivalence relation. Registers 103 // it in its own equivalence class and returns a pointer to the equivalence 104 // class representative. Register(const T & value)105 const T* Register(const T& value) { 106 assert(!Exists(value)); 107 108 // This relies on T having a copy constructor. 109 auto unique_pointer_to_value = MakeUnique<T>(value); 110 auto pointer_to_value = unique_pointer_to_value.get(); 111 owned_values_.push_back(std::move(unique_pointer_to_value)); 112 value_set_.insert(pointer_to_value); 113 114 // Initially say that the value is its own parent and that it has no 115 // children. 116 assert(pointer_to_value && "Representatives should never be null."); 117 parent_[pointer_to_value] = pointer_to_value; 118 children_[pointer_to_value] = std::vector<const T*>(); 119 120 return pointer_to_value; 121 } 122 123 // Returns exactly one representative per equivalence class. GetEquivalenceClassRepresentatives()124 std::vector<const T*> GetEquivalenceClassRepresentatives() const { 125 std::vector<const T*> result; 126 for (auto& value : owned_values_) { 127 if (parent_[value.get()] == value.get()) { 128 result.push_back(value.get()); 129 } 130 } 131 return result; 132 } 133 134 // Returns pointers to all values in the equivalence class of |value|, which 135 // must already be part of the equivalence relation. 
GetEquivalenceClass(const T & value)136 std::vector<const T*> GetEquivalenceClass(const T& value) const { 137 assert(Exists(value)); 138 139 std::vector<const T*> result; 140 141 // Traverse the tree of values rooted at the representative of the 142 // equivalence class to which |value| belongs, and collect up all the values 143 // that are encountered. This constitutes the whole equivalence class. 144 std::vector<const T*> stack; 145 stack.push_back(Find(*value_set_.find(&value))); 146 while (!stack.empty()) { 147 const T* item = stack.back(); 148 result.push_back(item); 149 stack.pop_back(); 150 for (auto child : children_[item]) { 151 stack.push_back(child); 152 } 153 } 154 return result; 155 } 156 157 // Returns true if and only if |value1| and |value2| are in the same 158 // equivalence class. Both values must already be known to the equivalence 159 // relation. IsEquivalent(const T & value1,const T & value2)160 bool IsEquivalent(const T& value1, const T& value2) const { 161 return Find(&value1) == Find(&value2); 162 } 163 164 // Returns all values known to be part of the equivalence relation. GetAllKnownValues()165 std::vector<const T*> GetAllKnownValues() const { 166 std::vector<const T*> result; 167 for (auto& value : owned_values_) { 168 result.push_back(value.get()); 169 } 170 return result; 171 } 172 173 // Returns true if and only if |value| is known to be part of the equivalence 174 // relation. Exists(const T & value)175 bool Exists(const T& value) const { 176 return value_set_.find(&value) != value_set_.end(); 177 } 178 179 // Returns the representative of the equivalence class of |value|, which must 180 // already be known to the equivalence relation. This is the 'Find' operation 181 // in a classic union-find data structure. Find(const T * value)182 const T* Find(const T* value) const { 183 assert(Exists(*value)); 184 185 // Get the canonical pointer to the value from the value pool. 
186 const T* known_value = *value_set_.find(value); 187 assert(parent_[known_value] && "Every known value should have a parent."); 188 189 // Compute the result by chasing parents until we find a value that is its 190 // own parent. 191 const T* result = known_value; 192 while (parent_[result] != result) { 193 result = parent_[result]; 194 } 195 assert(result && "Representatives should never be null."); 196 197 // At this point, |result| is the representative of the equivalence class. 198 // Now perform the 'path compression' optimization by doing another pass up 199 // the parent chain, setting the parent of each node to be the 200 // representative, and rewriting children correspondingly. 201 const T* current = known_value; 202 while (parent_[current] != result) { 203 const T* next = parent_[current]; 204 parent_[current] = result; 205 children_[result].push_back(current); 206 auto child_iterator = 207 std::find(children_[next].begin(), children_[next].end(), current); 208 assert(child_iterator != children_[next].end() && 209 "'next' is the parent of 'current', so 'current' should be a " 210 "child of 'next'"); 211 children_[next].erase(child_iterator); 212 current = next; 213 } 214 return result; 215 } 216 217 private: 218 // Maps every value to a parent. The representative of an equivalence class 219 // is its own parent. A value's representative can be found by walking its 220 // chain of ancestors. 221 // 222 // Mutable because the intuitively const method, 'Find', performs path 223 // compression. 224 mutable std::unordered_map<const T*, const T*> parent_; 225 226 // Stores the children of each value. This allows the equivalence class of 227 // a value to be calculated by traversing all descendents of the class's 228 // representative. 229 // 230 // Mutable because the intuitively const method, 'Find', performs path 231 // compression. 
232 mutable std::unordered_map<const T*, std::vector<const T*>> children_; 233 234 // The values known to the equivalence relation are allocated in 235 // |owned_values_|, and |value_pool_| provides (via |PointerHashT| and 236 // |PointerEqualsT|) a means for mapping a value of interest to a pointer 237 // into an equivalent value in |owned_values_|. 238 std::unordered_set<const T*, PointerHashT, PointerEqualsT> value_set_; 239 std::vector<std::unique_ptr<T>> owned_values_; 240 }; 241 242 } // namespace fuzz 243 } // namespace spvtools 244 245 #endif // SOURCE_FUZZ_EQUIVALENCE_RELATION_H_ 246