• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2022 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef SOURCE_DIFF_LCS_H_
16 #define SOURCE_DIFF_LCS_H_
17 
18 #include <algorithm>
19 #include <cassert>
20 #include <cstddef>
21 #include <cstdint>
22 #include <functional>
23 #include <stack>
24 #include <vector>
25 
namespace spvtools {
namespace diff {

// The result of a diff: one flag per element of a sequence, set to true when
// that element takes part in the matching.
using DiffMatch = std::vector<bool>;

// Helper class to find the longest common subsequence between two function
// bodies.
//
// `Sequence` is only required to provide `size()`, `empty()` and
// `operator[]`.  The object stores *references* to both sequences, so they
// must outlive it.  The memoization table is allocated eagerly by the
// constructor with `src.size() * dst.size()` entries.
template <typename Sequence>
class LongestCommonSubsequence {
 public:
  LongestCommonSubsequence(const Sequence& src, const Sequence& dst)
      : src_(src),
        dst_(dst),
        table_(src.size(), std::vector<DiffMatchEntry>(dst.size())) {}

  // Given two sequences, it creates a matching between them.  The elements are
  // simply marked as matched in src and dst, with any unmatched element in src
  // implying a removal and any unmatched element in dst implying an addition.
  //
  // `match` decides whether a src element and a dst element are considered
  // equal.  `src_match_result` and `dst_match_result` must be non-null; their
  // previous contents are discarded and they are resized to the lengths of
  // src and dst respectively.
  //
  // Returns the length of the longest common subsequence.
  template <typename T>
  uint32_t Get(std::function<bool(T src_elem, T dst_elem)> match,
               DiffMatch* src_match_result, DiffMatch* dst_match_result);

 private:
  // A position in the LCS table; refers to the suffixes src[src_offset:] and
  // dst[dst_offset:].
  struct DiffMatchIndex {
    uint32_t src_offset;
    uint32_t dst_offset;
  };

  // One memoized cell of the LCS table, packed into 32 bits.
  struct DiffMatchEntry {
    DiffMatchEntry() : best_match_length(0), matched(false), valid(false) {}

    // Length of the LCS of src[i:] and dst[j:].  Limited to 30 bits.
    uint32_t best_match_length : 30;
    // Whether src[i] and dst[j] matched.  This is an optimization to avoid
    // calling the `match` function again when walking the LCS table.
    uint32_t matched : 1;
    // Use for the recursive algorithm to know if the contents of this entry are
    // valid.
    uint32_t valid : 1;
  };

  // Fills in the entries of `table_` needed to know the LCS of the full
  // sequences.  Takes `match` by reference so that `Get` does not have to copy
  // the callable a second time.
  template <typename T>
  void CalculateLCS(const std::function<bool(T src_elem, T dst_elem)>& match);

  // Walks the filled table from {0, 0} and flags the matched elements.
  void RetrieveMatch(DiffMatch* src_match_result,
                     DiffMatch* dst_match_result) const;

  // True if `index` addresses an existing table entry.
  bool IsInBound(DiffMatchIndex index) const {
    return index.src_offset < src_.size() && index.dst_offset < dst_.size();
  }

  // True if the entry at `index` (which must be in bound) has been filled.
  bool IsCalculated(DiffMatchIndex index) const {
    assert(IsInBound(index));
    return table_[index.src_offset][index.dst_offset].valid;
  }

  // True if there is nothing left to compute for `index`: it is either past
  // the end of a sequence or already memoized.
  bool IsCalculatedOrOutOfBound(DiffMatchIndex index) const {
    return !IsInBound(index) || IsCalculated(index);
  }

  // Returns the memoized LCS length at `index`, or 0 when `index` is out of
  // bound (an empty suffix has no common subsequence with anything).
  uint32_t GetMemoizedLength(DiffMatchIndex index) const {
    if (!IsInBound(index)) {
      return 0;
    }
    assert(IsCalculated(index));
    return table_[index.src_offset][index.dst_offset].best_match_length;
  }

  // True if the elements at `index` were found to match each other.
  bool IsMatched(DiffMatchIndex index) const {
    assert(IsCalculated(index));
    return table_[index.src_offset][index.dst_offset].matched;
  }

  // Memoizes the result for `index`.  Each entry may only be written once.
  void MarkMatched(DiffMatchIndex index, uint32_t best_match_length,
                   bool matched) {
    assert(IsInBound(index));
    DiffMatchEntry& entry = table_[index.src_offset][index.dst_offset];
    assert(!entry.valid);

    // The length is stored in a 30-bit field; make sure nothing was truncated.
    entry.best_match_length = best_match_length & 0x3FFFFFFF;
    assert(entry.best_match_length == best_match_length);
    entry.matched = matched;
    entry.valid = true;
  }

  const Sequence& src_;
  const Sequence& dst_;

  // table_[i][j] describes the LCS of src_[i:] and dst_[j:].
  std::vector<std::vector<DiffMatchEntry>> table_;
};

template <typename Sequence>
template <typename T>
uint32_t LongestCommonSubsequence<Sequence>::Get(
    std::function<bool(T src_elem, T dst_elem)> match,
    DiffMatch* src_match_result, DiffMatch* dst_match_result) {
  CalculateLCS<T>(match);
  RetrieveMatch(src_match_result, dst_match_result);
  // Out of bound when either sequence is empty, in which case this yields 0.
  return GetMemoizedLength({0, 0});
}

template <typename Sequence>
template <typename T>
void LongestCommonSubsequence<Sequence>::CalculateLCS(
    const std::function<bool(T src_elem, T dst_elem)>& match) {
  // The LCS algorithm is simple.  Given sequences s and d, with a:b depicting a
  // range in python syntax:
  //
  //     lcs(s[i:], d[j:]) =
  //         lcs(s[i+1:], d[j+1:]) + 1                        if s[i] == d[j]
  //         max(lcs(s[i+1:], d[j:]), lcs(s[i:], d[j+1:]))               o.w.
  //
  // Once the LCS table is filled according to the above, it can be walked and
  // the best match retrieved.
  //
  // This is a recursive function with memoization, which avoids filling table
  // entries where unnecessary.  This makes the best case O(N) instead of
  // O(N^2).  The implementation uses a std::stack to avoid stack overflow on
  // long sequences.

  if (src_.empty() || dst_.empty()) {
    return;
  }

  std::stack<DiffMatchIndex> to_calculate;
  to_calculate.push({0, 0});

  while (!to_calculate.empty()) {
    const DiffMatchIndex current = to_calculate.top();
    to_calculate.pop();
    assert(IsInBound(current));

    // If already calculated through another path, ignore it.
    if (IsCalculated(current)) {
      continue;
    }

    if (match(src_[current.src_offset], dst_[current.dst_offset])) {
      // If the current elements match, advance both indices and calculate the
      // LCS if not already.  Visit `current` again afterwards, so its
      // corresponding entry will be updated.
      const DiffMatchIndex next = {current.src_offset + 1,
                                   current.dst_offset + 1};
      if (IsCalculatedOrOutOfBound(next)) {
        MarkMatched(current, GetMemoizedLength(next) + 1, true);
      } else {
        to_calculate.push(current);
        to_calculate.push(next);
      }
      continue;
    }

    // We've reached a pair of elements that don't match.  Calculate the LCS for
    // both cases of either being left unmatched and take the max.  Visit
    // `current` again afterwards, so its corresponding entry will be updated.
    const DiffMatchIndex next_src = {current.src_offset + 1,
                                     current.dst_offset};
    const DiffMatchIndex next_dst = {current.src_offset,
                                     current.dst_offset + 1};

    const bool src_ready = IsCalculatedOrOutOfBound(next_src);
    const bool dst_ready = IsCalculatedOrOutOfBound(next_dst);

    if (src_ready && dst_ready) {
      const uint32_t best_match_length =
          std::max(GetMemoizedLength(next_src), GetMemoizedLength(next_dst));
      MarkMatched(current, best_match_length, false);
      continue;
    }

    // `current` goes in first so that it is revisited only after the pending
    // neighbours pushed on top of it have been resolved.
    to_calculate.push(current);
    if (!src_ready) {
      to_calculate.push(next_src);
    }
    if (!dst_ready) {
      to_calculate.push(next_dst);
    }
  }
}

template <typename Sequence>
void LongestCommonSubsequence<Sequence>::RetrieveMatch(
    DiffMatch* src_match_result, DiffMatch* dst_match_result) const {
  assert(src_match_result != nullptr);
  assert(dst_match_result != nullptr);

  // Start from "nothing matched", discarding whatever the caller passed in.
  src_match_result->assign(src_.size(), false);
  dst_match_result->assign(dst_.size(), false);

  DiffMatchIndex current = {0, 0};
  while (IsInBound(current)) {
    if (IsMatched(current)) {
      (*src_match_result)[current.src_offset++] = true;
      (*dst_match_result)[current.dst_offset++] = true;
      continue;
    }

    // Follow the neighbour holding the longer remaining subsequence.  On a tie
    // the src element is the one skipped.
    if (GetMemoizedLength({current.src_offset + 1, current.dst_offset}) >=
        GetMemoizedLength({current.src_offset, current.dst_offset + 1})) {
      ++current.src_offset;
    } else {
      ++current.dst_offset;
    }
  }
}

}  // namespace diff
}  // namespace spvtools
223 
224 #endif  // SOURCE_DIFF_LCS_H_
225