1 // Copyright (c) 2022 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #ifndef SOURCE_DIFF_LCS_H_
16 #define SOURCE_DIFF_LCS_H_
17
18 #include <algorithm>
19 #include <cassert>
20 #include <cstddef>
21 #include <cstdint>
22 #include <functional>
23 #include <stack>
24 #include <vector>
25
26 namespace spvtools {
27 namespace diff {
28
// The result of a diff: one flag per element of a sequence, set to true when
// that element has been matched with an element of the other sequence.
using DiffMatch = std::vector<bool>;

// Helper class to find the longest common subsequence (LCS) between two
// sequences, such as two function bodies.
//
// `Sequence` is accessed through size(), empty() and operator[] only.  Both
// sequences are held by reference, so they must outlive this object.
template <typename Sequence>
class LongestCommonSubsequence {
 public:
  // Sets up an LCS computation over `src` and `dst`.  A memoization table with
  // src.size() x dst.size() entries is allocated up front; nothing is computed
  // until Get() is called.
  LongestCommonSubsequence(const Sequence& src, const Sequence& dst)
      : src_(src),
        dst_(dst),
        table_(src.size(), std::vector<DiffMatchEntry>(dst.size())) {}

  // Given two sequences, it creates a matching between them.  The elements are
  // simply marked as matched in src and dst, with any unmatched element in src
  // implying a removal and any unmatched element in dst implying an addition.
  //
  // `match` decides whether a src element and a dst element are equal.
  // `src_match_result` and `dst_match_result` are dereferenced
  // unconditionally (so they must not be null) and are overwritten with one
  // flag per element of src and dst respectively.
  //
  // Returns the length of the longest common subsequence.
  template <typename T>
  uint32_t Get(std::function<bool(T src_elem, T dst_elem)> match,
               DiffMatch* src_match_result, DiffMatch* dst_match_result);

 private:
  // A position in the LCS table: it stands for the pair of suffixes
  // src[src_offset:] and dst[dst_offset:].
  struct DiffMatchIndex {
    uint32_t src_offset;
    uint32_t dst_offset;
  };

  // Fills the parts of the table needed to know the LCS of the full sequences.
  template <typename T>
  void CalculateLCS(std::function<bool(T src_elem, T dst_elem)> match);
  // Walks the filled table and writes the per-element match flags.
  void RetrieveMatch(DiffMatch* src_match_result, DiffMatch* dst_match_result);

  // Whether `index` designates an existing table entry.
  bool IsInBound(DiffMatchIndex index) const {
    return index.src_offset < src_.size() && index.dst_offset < dst_.size();
  }
  // Whether the (in-bound) entry at `index` has already been filled.
  bool IsCalculated(DiffMatchIndex index) const {
    assert(IsInBound(index));
    return table_[index.src_offset][index.dst_offset].valid;
  }
  // Whether the value at `index` can be read right away: either it lies past
  // the end of a sequence (implicitly 0) or its entry has been filled.
  bool IsCalculatedOrOutOfBound(DiffMatchIndex index) const {
    return !IsInBound(index) || IsCalculated(index);
  }
  // Returns the memoized LCS length at `index`; out-of-bound positions
  // correspond to an empty suffix and therefore yield 0.
  uint32_t GetMemoizedLength(DiffMatchIndex index) const {
    if (!IsInBound(index)) {
      return 0;
    }
    assert(IsCalculated(index));
    return table_[index.src_offset][index.dst_offset].best_match_length;
  }
  // Whether src[src_offset] and dst[dst_offset] were recorded as matching.
  bool IsMatched(DiffMatchIndex index) const {
    assert(IsCalculated(index));
    return table_[index.src_offset][index.dst_offset].matched;
  }
  // Fills the entry at `index`.  Each entry may be filled only once.
  void MarkMatched(DiffMatchIndex index, uint32_t best_match_length,
                   bool matched) {
    assert(IsInBound(index));
    DiffMatchEntry& entry = table_[index.src_offset][index.dst_offset];
    assert(!entry.valid);

    // The length is stored in a 30-bit field; the mask makes the truncation
    // explicit and the assert catches lengths that do not fit.
    entry.best_match_length = best_match_length & 0x3FFFFFFF;
    assert(entry.best_match_length == best_match_length);
    entry.matched = matched;
    entry.valid = true;
  }

  const Sequence& src_;
  const Sequence& dst_;

  // One cell of the memoization table, packed into 32 bits.
  struct DiffMatchEntry {
    DiffMatchEntry() : best_match_length(0), matched(false), valid(false) {}

    // Length of the LCS of src[i:] and dst[j:].
    uint32_t best_match_length : 30;
    // Whether src[i] and dst[j] matched.  This is an optimization to avoid
    // calling the `match` function again when walking the LCS table.
    uint32_t matched : 1;
    // Use for the recursive algorithm to know if the contents of this entry are
    // valid.
    uint32_t valid : 1;
  };

  // table_[i][j] describes the pair of suffixes src[i:] and dst[j:].
  std::vector<std::vector<DiffMatchEntry>> table_;
};
110
111 template <typename Sequence>
112 template <typename T>
Get(std::function<bool (T src_elem,T dst_elem)> match,DiffMatch * src_match_result,DiffMatch * dst_match_result)113 uint32_t LongestCommonSubsequence<Sequence>::Get(
114 std::function<bool(T src_elem, T dst_elem)> match,
115 DiffMatch* src_match_result, DiffMatch* dst_match_result) {
116 CalculateLCS(match);
117 RetrieveMatch(src_match_result, dst_match_result);
118 return GetMemoizedLength({0, 0});
119 }
120
121 template <typename Sequence>
122 template <typename T>
CalculateLCS(std::function<bool (T src_elem,T dst_elem)> match)123 void LongestCommonSubsequence<Sequence>::CalculateLCS(
124 std::function<bool(T src_elem, T dst_elem)> match) {
125 // The LCS algorithm is simple. Given sequences s and d, with a:b depicting a
126 // range in python syntax:
127 //
128 // lcs(s[i:], d[j:]) =
129 // lcs(s[i+1:], d[j+1:]) + 1 if s[i] == d[j]
130 // max(lcs(s[i+1:], d[j:]), lcs(s[i:], d[j+1:])) o.w.
131 //
132 // Once the LCS table is filled according to the above, it can be walked and
133 // the best match retrieved.
134 //
135 // This is a recursive function with memoization, which avoids filling table
136 // entries where unnecessary. This makes the best case O(N) instead of
137 // O(N^2). The implemention uses a std::stack to avoid stack overflow on long
138 // sequences.
139
140 if (src_.empty() || dst_.empty()) {
141 return;
142 }
143
144 std::stack<DiffMatchIndex> to_calculate;
145 to_calculate.push({0, 0});
146
147 while (!to_calculate.empty()) {
148 DiffMatchIndex current = to_calculate.top();
149 to_calculate.pop();
150 assert(IsInBound(current));
151
152 // If already calculated through another path, ignore it.
153 if (IsCalculated(current)) {
154 continue;
155 }
156
157 if (match(src_[current.src_offset], dst_[current.dst_offset])) {
158 // If the current elements match, advance both indices and calculate the
159 // LCS if not already. Visit `current` again afterwards, so its
160 // corresponding entry will be updated.
161 DiffMatchIndex next = {current.src_offset + 1, current.dst_offset + 1};
162 if (IsCalculatedOrOutOfBound(next)) {
163 MarkMatched(current, GetMemoizedLength(next) + 1, true);
164 } else {
165 to_calculate.push(current);
166 to_calculate.push(next);
167 }
168 continue;
169 }
170
171 // We've reached a pair of elements that don't match. Calculate the LCS for
172 // both cases of either being left unmatched and take the max. Visit
173 // `current` again afterwards, so its corresponding entry will be updated.
174 DiffMatchIndex next_src = {current.src_offset + 1, current.dst_offset};
175 DiffMatchIndex next_dst = {current.src_offset, current.dst_offset + 1};
176
177 if (IsCalculatedOrOutOfBound(next_src) &&
178 IsCalculatedOrOutOfBound(next_dst)) {
179 uint32_t best_match_length =
180 std::max(GetMemoizedLength(next_src), GetMemoizedLength(next_dst));
181 MarkMatched(current, best_match_length, false);
182 continue;
183 }
184
185 to_calculate.push(current);
186 if (!IsCalculatedOrOutOfBound(next_src)) {
187 to_calculate.push(next_src);
188 }
189 if (!IsCalculatedOrOutOfBound(next_dst)) {
190 to_calculate.push(next_dst);
191 }
192 }
193 }
194
195 template <typename Sequence>
RetrieveMatch(DiffMatch * src_match_result,DiffMatch * dst_match_result)196 void LongestCommonSubsequence<Sequence>::RetrieveMatch(
197 DiffMatch* src_match_result, DiffMatch* dst_match_result) {
198 src_match_result->clear();
199 dst_match_result->clear();
200
201 src_match_result->resize(src_.size(), false);
202 dst_match_result->resize(dst_.size(), false);
203
204 DiffMatchIndex current = {0, 0};
205 while (IsInBound(current)) {
206 if (IsMatched(current)) {
207 (*src_match_result)[current.src_offset++] = true;
208 (*dst_match_result)[current.dst_offset++] = true;
209 continue;
210 }
211
212 if (GetMemoizedLength({current.src_offset + 1, current.dst_offset}) >=
213 GetMemoizedLength({current.src_offset, current.dst_offset + 1})) {
214 ++current.src_offset;
215 } else {
216 ++current.dst_offset;
217 }
218 }
219 }
220
221 } // namespace diff
222 } // namespace spvtools
223
224 #endif // SOURCE_DIFF_LCS_H_
225