• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #ifndef MARISA_GRIMOIRE_ALGORITHM_SORT_H_
2 #define MARISA_GRIMOIRE_ALGORITHM_SORT_H_
3 
4 #include "marisa/base.h"
5 
6 namespace marisa {
7 namespace grimoire {
8 namespace algorithm {
9 namespace details {
10 
11 enum {
12   MARISA_INSERTION_SORT_THRESHOLD = 10
13 };
14 
15 template <typename T>
get_label(const T & unit,std::size_t depth)16 int get_label(const T &unit, std::size_t depth) {
17   MARISA_DEBUG_IF(depth > unit.length(), MARISA_BOUND_ERROR);
18 
19   return (depth < unit.length()) ? (int)(UInt8)unit[depth] : -1;
20 }
21 
22 template <typename T>
median(const T & a,const T & b,const T & c,std::size_t depth)23 int median(const T &a, const T &b, const T &c, std::size_t depth) {
24   const int x = get_label(a, depth);
25   const int y = get_label(b, depth);
26   const int z = get_label(c, depth);
27   if (x < y) {
28     if (y < z) {
29       return y;
30     } else if (x < z) {
31       return z;
32     }
33     return x;
34   } else if (x < z) {
35     return x;
36   } else if (y < z) {
37     return z;
38   }
39   return y;
40 }
41 
42 template <typename T>
compare(const T & lhs,const T & rhs,std::size_t depth)43 int compare(const T &lhs, const T &rhs, std::size_t depth) {
44   for (std::size_t i = depth; i < lhs.length(); ++i) {
45     if (i == rhs.length()) {
46       return 1;
47     }
48     if (lhs[i] != rhs[i]) {
49       return (UInt8)lhs[i] - (UInt8)rhs[i];
50     }
51   }
52   if (lhs.length() == rhs.length()) {
53     return 0;
54   }
55   return (lhs.length() < rhs.length()) ? -1 : 1;
56 }
57 
58 template <typename Iterator>
insertion_sort(Iterator l,Iterator r,std::size_t depth)59 std::size_t insertion_sort(Iterator l, Iterator r, std::size_t depth) {
60   MARISA_DEBUG_IF(l > r, MARISA_BOUND_ERROR);
61 
62   std::size_t count = 1;
63   for (Iterator i = l + 1; i < r; ++i) {
64     int result = 0;
65     for (Iterator j = i; j > l; --j) {
66       result = compare(*(j - 1), *j, depth);
67       if (result <= 0) {
68         break;
69       }
70       marisa::swap(*(j - 1), *j);
71     }
72     if (result != 0) {
73       ++count;
74     }
75   }
76   return count;
77 }
78 
79 template <typename Iterator>
sort(Iterator l,Iterator r,std::size_t depth)80 std::size_t sort(Iterator l, Iterator r, std::size_t depth) {
81   MARISA_DEBUG_IF(l > r, MARISA_BOUND_ERROR);
82 
83   std::size_t count = 0;
84   while ((r - l) > MARISA_INSERTION_SORT_THRESHOLD) {
85     Iterator pl = l;
86     Iterator pr = r;
87     Iterator pivot_l = l;
88     Iterator pivot_r = r;
89 
90     const int pivot = median(*l, *(l + (r - l) / 2), *(r - 1), depth);
91     for ( ; ; ) {
92       while (pl < pr) {
93         const int label = get_label(*pl, depth);
94         if (label > pivot) {
95           break;
96         } else if (label == pivot) {
97           marisa::swap(*pl, *pivot_l);
98           ++pivot_l;
99         }
100         ++pl;
101       }
102       while (pl < pr) {
103         const int label = get_label(*--pr, depth);
104         if (label < pivot) {
105           break;
106         } else if (label == pivot) {
107           marisa::swap(*pr, *--pivot_r);
108         }
109       }
110       if (pl >= pr) {
111         break;
112       }
113       marisa::swap(*pl, *pr);
114       ++pl;
115     }
116     while (pivot_l > l) {
117       marisa::swap(*--pivot_l, *--pl);
118     }
119     while (pivot_r < r) {
120       marisa::swap(*pivot_r, *pr);
121       ++pivot_r;
122       ++pr;
123     }
124 
125     if (((pl - l) > (pr - pl)) || ((r - pr) > (pr - pl))) {
126       if ((pr - pl) == 1) {
127         ++count;
128       } else if ((pr - pl) > 1) {
129         if (pivot == -1) {
130           ++count;
131         } else {
132           count += sort(pl, pr, depth + 1);
133         }
134       }
135 
136       if ((pl - l) < (r - pr)) {
137         if ((pl - l) == 1) {
138           ++count;
139         } else if ((pl - l) > 1) {
140           count += sort(l, pl, depth);
141         }
142         l = pr;
143       } else {
144         if ((r - pr) == 1) {
145           ++count;
146         } else if ((r - pr) > 1) {
147           count += sort(pr, r, depth);
148         }
149         r = pl;
150       }
151     } else {
152       if ((pl - l) == 1) {
153         ++count;
154       } else if ((pl - l) > 1) {
155         count += sort(l, pl, depth);
156       }
157 
158       if ((r - pr) == 1) {
159         ++count;
160       } else if ((r - pr) > 1) {
161         count += sort(pr, r, depth);
162       }
163 
164       l = pl, r = pr;
165       if ((pr - pl) == 1) {
166         ++count;
167       } else if ((pr - pl) > 1) {
168         if (pivot == -1) {
169           l = r;
170           ++count;
171         } else {
172           ++depth;
173         }
174       }
175     }
176   }
177 
178   if ((r - l) > 1) {
179     count += insertion_sort(l, r, depth);
180   }
181   return count;
182 }
183 
184 }  // namespace details
185 
186 template <typename Iterator>
sort(Iterator begin,Iterator end)187 std::size_t sort(Iterator begin, Iterator end) {
188   MARISA_DEBUG_IF(begin > end, MARISA_BOUND_ERROR);
189   return details::sort(begin, end, 0);
190 };
191 
192 }  // namespace algorithm
193 }  // namespace grimoire
194 }  // namespace marisa
195 
196 #endif  // MARISA_GRIMOIRE_ALGORITHM_SORT_H_
197