1 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
2 // -*- mode: C++ -*-
3 //
4 // Copyright 2022-2024 Google LLC
5 //
6 // Licensed under the Apache License v2.0 with LLVM Exceptions (the
7 // "License"); you may not use this file except in compliance with the
8 // License. You may obtain a copy of the License at
9 //
10 // https://llvm.org/LICENSE.txt
11 //
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
17 //
18 // Author: Giuliano Procida
19
20 #include "unification.h"
21
22 #include <cstddef>
23 #include <exception>
24 #include <map>
25 #include <optional>
26 #include <utility>
27 #include <unordered_map>
28 #include <unordered_set>
29 #include <vector>
30
31 #include "graph.h"
32 #include "runtime.h"
33 #include "substitution.h"
34
35 namespace stg {
36
37 namespace {
38
39 // Type Unification
40 //
41 // This is very similar to Equals. The differences are the recursion control,
42 // caching and handling of StructUnion and Enum nodes.
43 //
44 // During unification, keep track of which pairs of types need to be equal, but
45 // do not add them immediately to the unification substitutions. The caller can
46 // do that if the whole unification succeeds.
47 //
48 // A declaration and definition of the same named type can be unified. This is
49 // forward declaration resolution.
50 struct Unifier {
51 enum Winner { Neither, Right, Left }; // makes p ? Right : Neither a no-op
52
Unifierstg::__anon29f890110111::Unifier53 Unifier(const Graph& graph, Unification& unification)
54 : graph(graph), unification(unification) {}
55
operator ()stg::__anon29f890110111::Unifier56 bool operator()(Id id1, Id id2) {
57 Id fid1 = Find(id1);
58 Id fid2 = Find(id2);
59 if (fid1 == fid2) {
60 return true;
61 }
62
63 // Check if the comparison has been (or is being) visited already. We don't
64 // need an SCC finder as any failure to unify will poison the entire DFS.
65 //
66 // This prevents infinite recursion, but maybe not immediately as seen is
67 // unaware of new mappings.
68 if (!seen.emplace(fid1, fid2).second) {
69 return true;
70 }
71
72 const auto winner = graph.Apply2(*this, fid1, fid2);
73 if (winner == Neither) {
74 return false;
75 }
76
77 // These will occasionally get substituted due to a recursive call.
78 fid1 = Find(fid1);
79 fid2 = Find(fid2);
80 if (fid1 == fid2) {
81 return true;
82 }
83
84 if (winner == Left) {
85 std::swap(fid1, fid2);
86 }
87 mapping.insert({fid1, fid2});
88
89 return true;
90 }
91
operator ()stg::__anon29f890110111::Unifier92 bool operator()(const std::optional<Id>& opt1,
93 const std::optional<Id>& opt2) {
94 if (opt1.has_value() && opt2.has_value()) {
95 return (*this)(opt1.value(), opt2.value());
96 }
97 return opt1.has_value() == opt2.has_value();
98 }
99
operator ()stg::__anon29f890110111::Unifier100 bool operator()(const std::vector<Id>& ids1, const std::vector<Id>& ids2) {
101 bool result = ids1.size() == ids2.size();
102 for (size_t ix = 0; result && ix < ids1.size(); ++ix) {
103 result = (*this)(ids1[ix], ids2[ix]);
104 }
105 return result;
106 }
107
108 template <typename Key>
operator ()stg::__anon29f890110111::Unifier109 bool operator()(const std::map<Key, Id>& ids1,
110 const std::map<Key, Id>& ids2) {
111 bool result = ids1.size() == ids2.size();
112 auto it1 = ids1.begin();
113 auto it2 = ids2.begin();
114 const auto end1 = ids1.end();
115 const auto end2 = ids2.end();
116 while (result && it1 != end1 && it2 != end2) {
117 result = it1->first == it2->first
118 && (*this)(it1->second, it2->second);
119 ++it1;
120 ++it2;
121 }
122 return result && it1 == end1 && it2 == end2;
123 }
124
operator ()stg::__anon29f890110111::Unifier125 Winner operator()(const Special& x1, const Special& x2) {
126 return x1.kind == x2.kind
127 ? Right : Neither;
128 }
129
operator ()stg::__anon29f890110111::Unifier130 Winner operator()(const PointerReference& x1,
131 const PointerReference& x2) {
132 return x1.kind == x2.kind
133 && (*this)(x1.pointee_type_id, x2.pointee_type_id)
134 ? Right : Neither;
135 }
136
operator ()stg::__anon29f890110111::Unifier137 Winner operator()(const PointerToMember& x1, const PointerToMember& x2) {
138 return (*this)(x1.containing_type_id, x2.containing_type_id)
139 && (*this)(x1.pointee_type_id, x2.pointee_type_id)
140 ? Right : Neither;
141 }
142
operator ()stg::__anon29f890110111::Unifier143 Winner operator()(const Typedef& x1, const Typedef& x2) {
144 return x1.name == x2.name
145 && (*this)(x1.referred_type_id, x2.referred_type_id)
146 ? Right : Neither;
147 }
148
operator ()stg::__anon29f890110111::Unifier149 Winner operator()(const Qualified& x1, const Qualified& x2) {
150 return x1.qualifier == x2.qualifier
151 && (*this)(x1.qualified_type_id, x2.qualified_type_id)
152 ? Right : Neither;
153 }
154
operator ()stg::__anon29f890110111::Unifier155 Winner operator()(const Primitive& x1, const Primitive& x2) {
156 return x1.name == x2.name
157 && x1.encoding == x2.encoding
158 && x1.bytesize == x2.bytesize
159 ? Right : Neither;
160 }
161
operator ()stg::__anon29f890110111::Unifier162 Winner operator()(const Array& x1, const Array& x2) {
163 return x1.number_of_elements == x2.number_of_elements
164 && (*this)(x1.element_type_id, x2.element_type_id)
165 ? Right : Neither;
166 }
167
operator ()stg::__anon29f890110111::Unifier168 Winner operator()(const BaseClass& x1, const BaseClass& x2) {
169 return x1.offset == x2.offset
170 && x1.inheritance == x2.inheritance
171 && (*this)(x1.type_id, x2.type_id)
172 ? Right : Neither;
173 }
174
operator ()stg::__anon29f890110111::Unifier175 Winner operator()(const Method& x1, const Method& x2) {
176 return x1.mangled_name == x2.mangled_name
177 && x1.name == x2.name
178 && x1.vtable_offset == x2.vtable_offset
179 && (*this)(x1.type_id, x2.type_id)
180 ? Right : Neither;
181 }
182
operator ()stg::__anon29f890110111::Unifier183 Winner operator()(const Member& x1, const Member& x2) {
184 return x1.name == x2.name
185 && x1.offset == x2.offset
186 && x1.bitsize == x2.bitsize
187 && (*this)(x1.type_id, x2.type_id)
188 ? Right : Neither;
189 }
190
operator ()stg::__anon29f890110111::Unifier191 Winner operator()(const VariantMember& x1, const VariantMember& x2) {
192 return x1.name == x2.name
193 && x1.discriminant_value == x2.discriminant_value
194 && (*this)(x1.type_id, x2.type_id)
195 ? Right : Neither;
196 }
197
operator ()stg::__anon29f890110111::Unifier198 Winner operator()(const StructUnion& x1, const StructUnion& x2) {
199 const auto& definition1 = x1.definition;
200 const auto& definition2 = x2.definition;
201 bool result = x1.kind == x2.kind
202 && x1.name == x2.name;
203 // allow mismatches as forward declarations are always unifiable
204 if (result && definition1.has_value() && definition2.has_value()) {
205 result = definition1->bytesize == definition2->bytesize
206 && (*this)(definition1->base_classes, definition2->base_classes)
207 && (*this)(definition1->methods, definition2->methods)
208 && (*this)(definition1->members, definition2->members);
209 }
210 return result ? definition2.has_value() ? Right : Left : Neither;
211 }
212
operator ()stg::__anon29f890110111::Unifier213 Winner operator()(const Enumeration& x1, const Enumeration& x2) {
214 const auto& definition1 = x1.definition;
215 const auto& definition2 = x2.definition;
216 bool result = x1.name == x2.name;
217 // allow mismatches as forward declarations are always unifiable
218 if (result && definition1.has_value() && definition2.has_value()) {
219 result = (*this)(definition1->underlying_type_id,
220 definition2->underlying_type_id)
221 && definition1->enumerators == definition2->enumerators;
222 }
223 return result ? definition2.has_value() ? Right : Left : Neither;
224 }
225
operator ()stg::__anon29f890110111::Unifier226 Winner operator()(const Variant& x1, const Variant& x2) {
227 return x1.name == x2.name
228 && x1.bytesize == x2.bytesize
229 && (*this)(x1.discriminant, x2.discriminant)
230 && (*this)(x1.members, x2.members)
231 ? Right : Neither;
232 }
233
operator ()stg::__anon29f890110111::Unifier234 Winner operator()(const Function& x1, const Function& x2) {
235 return (*this)(x1.parameters, x2.parameters)
236 && (*this)(x1.return_type_id, x2.return_type_id)
237 ? Right : Neither;
238 }
239
operator ()stg::__anon29f890110111::Unifier240 Winner operator()(const ElfSymbol& x1, const ElfSymbol& x2) {
241 bool result = x1.symbol_name == x2.symbol_name
242 && x1.version_info == x2.version_info
243 && x1.is_defined == x2.is_defined
244 && x1.symbol_type == x2.symbol_type
245 && x1.binding == x2.binding
246 && x1.visibility == x2.visibility
247 && x1.crc == x2.crc
248 && x1.ns == x2.ns
249 && x1.full_name == x2.full_name
250 && x1.type_id.has_value() == x2.type_id.has_value();
251 if (result && x1.type_id.has_value()) {
252 result = (*this)(x1.type_id.value(), x2.type_id.value());
253 }
254 return result ? Right : Neither;
255 }
256
operator ()stg::__anon29f890110111::Unifier257 Winner operator()(const Interface& x1, const Interface& x2) {
258 return (*this)(x1.symbols, x2.symbols)
259 && (*this)(x1.types, x2.types)
260 ? Right : Neither;
261 }
262
Mismatchstg::__anon29f890110111::Unifier263 Winner Mismatch() {
264 return Neither;
265 }
266
Findstg::__anon29f890110111::Unifier267 Id Find(Id id) {
268 while (true) {
269 id = unification.Find(id);
270 auto it = mapping.find(id);
271 if (it != mapping.end()) {
272 id = it->second;
273 continue;
274 }
275 return id;
276 }
277 }
278
279 const Graph& graph;
280 Unification& unification;
281 std::unordered_set<Pair> seen;
282 std::unordered_map<Id, Id> mapping;
283 };
284
285 } // namespace
286
Unification(Runtime & runtime,Graph & graph,Id start,Id limit)287 Unification::Unification(Runtime& runtime, Graph& graph, Id start, Id limit)
288 : graph_(graph),
289 start_(start),
290 mapping_(start, limit),
291 runtime_(runtime),
292 find_query_(runtime, "unification.find_query"),
293 find_halved_(runtime, "unification.find_halved"),
294 union_known_(runtime, "unification.union_known"),
295 union_unknown_(runtime, "unification.union_unknown") {}
296
~Unification()297 Unification::~Unification() noexcept(false) {
298 if (std::uncaught_exceptions() > 0) {
299 // abort unification
300 return;
301 }
302 // apply substitutions to the entire graph
303 const Time time(runtime_, "unification.rewrite");
304 Counter removed(runtime_, "unification.removed");
305 Counter retained(runtime_, "unification.retained");
306 const auto remap = [&](Id& id) {
307 // update id to representative id, avoiding silent stores
308 const Id fid = Find(id);
309 if (fid != id) {
310 id = fid;
311 }
312 };
313 const Substitute substitute(graph_, remap);
314 graph_.ForEach(start_, graph_.Limit(), [&](Id id) {
315 if (Find(id) != id) {
316 graph_.Remove(id);
317 ++removed;
318 } else {
319 substitute(id);
320 ++retained;
321 }
322 });
323 }
324
Union(Id id1,Id id2)325 void Unification::Union(Id id1, Id id2) {
326 // always prefer Find(id2) as a parent
327 const Id fid1 = Find(id1);
328 const Id fid2 = Find(id2);
329 if (fid1 == fid2) {
330 ++union_known_;
331 return;
332 }
333 mapping_[fid1] = fid2;
334 ++union_unknown_;
335 }
336
Find(Id id)337 Id Unification::Find(Id id) {
338 ++find_query_;
339 // path halving - tiny performance gain
340 while (true) {
341 // note: safe to take a reference as mapping cannot grow after this
342 auto& parent = mapping_[id];
343 if (parent == id) {
344 return id;
345 }
346 const auto parent_parent = mapping_[parent];
347 if (parent_parent == parent) {
348 return parent;
349 }
350 id = parent = parent_parent;
351 ++find_halved_;
352 }
353 }
354
Unify(Id id1,Id id2)355 bool Unification::Unify(Id id1, Id id2) {
356 // TODO: Unifier only needs access to Unification::Find
357 Unifier unifier(graph_, *this);
358 if (unifier(id1, id2)) {
359 // commit
360 for (const auto& s : unifier.mapping) {
361 Union(s.first, s.second);
362 }
363 return true;
364 }
365 return false;
366 }
367
368 } // namespace stg
369