• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// -*- mode: C++ -*-
//
// Copyright 2022-2023 Google LLC
//
// Licensed under the Apache License v2.0 with LLVM Exceptions (the
// "License"); you may not use this file except in compliance with the
// License.  You may obtain a copy of the License at
//
//     https://llvm.org/LICENSE.txt
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Author: Giuliano Procida
19 
20 #ifndef STG_UNIFICATION_H_
21 #define STG_UNIFICATION_H_
22 
23 #include <exception>
24 #include <unordered_map>
25 #include <unordered_set>
26 
27 #include "graph.h"
28 #include "runtime.h"
29 #include "substitution.h"
30 
31 namespace stg {
32 
33 // Keep track of which nodes are pending substitution and rewrite the graph on
34 // destruction.
35 class Unification {
36  public:
Unification(Runtime & runtime,Graph & graph,Id start)37   Unification(Runtime& runtime, Graph& graph, Id start)
38       : graph_(graph),
39         start_(start),
40         mapping_(start),
41         runtime_(runtime),
42         find_query_(runtime, "unification.find_query"),
43         find_halved_(runtime, "unification.find_halved"),
44         union_known_(runtime, "unification.union_known"),
45         union_unknown_(runtime, "unification.union_unknown") {}
46 
~Unification()47   ~Unification() {
48     if (std::uncaught_exceptions() > 0) {
49       // abort unification
50       return;
51     }
52     // apply substitutions to the entire graph
53     const Time time(runtime_, "unification.rewrite");
54     Counter removed(runtime_, "unification.removed");
55     Counter retained(runtime_, "unification.retained");
56     auto remap = [&](Id& id) {
57       Update(id);
58     };
59     ::stg::Substitute substitute(graph_, remap);
60     graph_.ForEach(start_, graph_.Limit(), [&](Id id) {
61       if (Find(id) != id) {
62         graph_.Remove(id);
63         ++removed;
64       } else {
65         substitute(id);
66         ++retained;
67       }
68     });
69   }
70 
Reserve(Id limit)71   void Reserve(Id limit) {
72     mapping_.Reserve(limit);
73   }
74 
75   bool Unify(Id id1, Id id2);
76 
Find(Id id)77   Id Find(Id id) {
78     ++find_query_;
79     // path halving - tiny performance gain
80     while (true) {
81       // note: safe to take references as mapping cannot grow after this
82       auto& parent = mapping_[id];
83       if (parent == id) {
84         return id;
85       }
86       auto& parent_parent = mapping_[parent];
87       if (parent_parent == parent) {
88         return parent;
89       }
90       id = parent = parent_parent;
91       ++find_halved_;
92     }
93   }
94 
Union(Id id1,Id id2)95   void Union(Id id1, Id id2) {
96     // id2 will always be preferred as a parent node; interpreted as a
97     // substitution, id1 will be replaced by id2
98     const Id fid1 = Find(id1);
99     const Id fid2 = Find(id2);
100     if (fid1 == fid2) {
101       ++union_known_;
102       return;
103     }
104     mapping_[fid1] = fid2;
105     ++union_unknown_;
106   }
107 
108   // update id to representative id
Update(Id & id)109   void Update(Id& id) {
110     const Id fid = Find(id);
111     // avoid silent stores
112     if (fid != id) {
113       id = fid;
114     }
115   }
116 
117  private:
118   Graph& graph_;
119   Id start_;
120   DenseIdMapping mapping_;
121   Runtime& runtime_;
122   Counter find_query_;
123   Counter find_halved_;
124   Counter union_known_;
125   Counter union_unknown_;
126 };
127 
128 }  // namespace stg
129 
130 #endif  // STG_UNIFICATION_H_
131