1 #ifndef MARISA_GRIMOIRE_ALGORITHM_SORT_H_
2 #define MARISA_GRIMOIRE_ALGORITHM_SORT_H_
3
4 #include "marisa/base.h"
5
6 namespace marisa {
7 namespace grimoire {
8 namespace algorithm {
9 namespace details {
10
// Ranges holding at most this many keys are not partitioned any further by
// sort(); they are finished with insertion_sort() instead.
enum { MARISA_INSERTION_SORT_THRESHOLD = 10 };
14
15 template <typename T>
get_label(const T & unit,std::size_t depth)16 int get_label(const T &unit, std::size_t depth) {
17 MARISA_DEBUG_IF(depth > unit.length(), MARISA_BOUND_ERROR);
18
19 return (depth < unit.length()) ? (int)(UInt8)unit[depth] : -1;
20 }
21
// Returns the median of the labels (see get_label()) of `a`, `b` and `c` at
// position `depth`. sort() uses it to choose the partitioning pivot.
template <typename T>
int median(const T &a, const T &b, const T &c, std::size_t depth) {
  int lo = get_label(a, depth);
  int mid = get_label(b, depth);
  const int hi = get_label(c, depth);

  // Order the first two labels so that lo <= mid.
  if (mid < lo) {
    const int tmp = lo;
    lo = mid;
    mid = tmp;
  }
  // median(x, y, z) == max(min(x, y), min(max(x, y), z)):
  // first clamp from above by the third label, then from below by lo.
  if (hi < mid) {
    mid = hi;
  }
  if (mid < lo) {
    mid = lo;
  }
  return mid;
}
41
42 template <typename T>
compare(const T & lhs,const T & rhs,std::size_t depth)43 int compare(const T &lhs, const T &rhs, std::size_t depth) {
44 for (std::size_t i = depth; i < lhs.length(); ++i) {
45 if (i == rhs.length()) {
46 return 1;
47 }
48 if (lhs[i] != rhs[i]) {
49 return (UInt8)lhs[i] - (UInt8)rhs[i];
50 }
51 }
52 if (lhs.length() == rhs.length()) {
53 return 0;
54 }
55 return (lhs.length() < rhs.length()) ? -1 : 1;
56 }
57
58 template <typename Iterator>
insertion_sort(Iterator l,Iterator r,std::size_t depth)59 std::size_t insertion_sort(Iterator l, Iterator r, std::size_t depth) {
60 MARISA_DEBUG_IF(l > r, MARISA_BOUND_ERROR);
61
62 std::size_t count = 1;
63 for (Iterator i = l + 1; i < r; ++i) {
64 int result = 0;
65 for (Iterator j = i; j > l; --j) {
66 result = compare(*(j - 1), *j, depth);
67 if (result <= 0) {
68 break;
69 }
70 marisa::swap(*(j - 1), *j);
71 }
72 if (result != 0) {
73 ++count;
74 }
75 }
76 return count;
77 }
78
// Sorts the keys in [l, r) in ascending order, looking at bytes from position
// `depth` onward, and returns the number of distinct keys in the range.
//
// This is a three-way radix quicksort (multikey quicksort). Each pass
// partitions the range by the label at `depth` (see get_label(); -1 means the
// key ends there) into "< pivot", "== pivot" and "> pivot" parts. The
// "== pivot" part is then sorted at depth + 1, the other two at `depth`.
// Ranges of at most MARISA_INSERTION_SORT_THRESHOLD keys are finished with
// insertion_sort().
//
// The loop keeps working on a largest partition and only recurses on the
// others.
//
// Relies on every key in [l, r) sharing its first `depth` bytes. This holds
// for depth 0 and is preserved by the recursion. It is what lets a
// `pivot == -1` group (all keys ending at `depth`) be counted as one key.
//
// NOTE: a range with fewer than two keys yields 0 here. The call sites in this
// function count single-key partitions themselves instead of recursing.
template <typename Iterator>
std::size_t sort(Iterator l, Iterator r, std::size_t depth) {
  MARISA_DEBUG_IF(l > r, MARISA_BOUND_ERROR);

  std::size_t count = 0;
  while ((r - l) > MARISA_INSERTION_SORT_THRESHOLD) {
    // Scan cursors: pl moves right from l, pr moves left from r.
    Iterator pl = l;
    Iterator pr = r;
    // Keys equal to the pivot are parked at both ends of the range, in
    // [l, pivot_l) and [pivot_r, r), and moved to the middle after the scan.
    Iterator pivot_l = l;
    Iterator pivot_r = r;

    // Median-of-three pivot from the first, middle and last keys' labels.
    const int pivot = median(*l, *(l + (r - l) / 2), *(r - 1), depth);
    for ( ; ; ) {
      // Advance pl over labels <= pivot, parking equal ones on the left.
      while (pl < pr) {
        const int label = get_label(*pl, depth);
        if (label > pivot) {
          break;
        } else if (label == pivot) {
          marisa::swap(*pl, *pivot_l);
          ++pivot_l;
        }
        ++pl;
      }
      // Move pr back over labels >= pivot, parking equal ones on the right.
      while (pl < pr) {
        const int label = get_label(*--pr, depth);
        if (label < pivot) {
          break;
        } else if (label == pivot) {
          marisa::swap(*pr, *--pivot_r);
        }
      }
      if (pl >= pr) {
        break;
      }
      // *pl is above the pivot and *pr below it: exchange them.
      marisa::swap(*pl, *pr);
      ++pl;
    }
    // Bring the parked equal keys into the middle. Afterwards the range is
    // [l, pl) "< pivot", [pl, pr) "== pivot", [pr, r) "> pivot".
    while (pivot_l > l) {
      marisa::swap(*--pivot_l, *--pl);
    }
    while (pivot_r < r) {
      marisa::swap(*pivot_r, *pr);
      ++pivot_r;
      ++pr;
    }

    if (((pl - l) > (pr - pl)) || ((r - pr) > (pr - pl))) {
      // The "== pivot" part is smaller than one of the others: handle it (one
      // byte deeper) and the smaller side now, then loop on the larger side.
      if ((pr - pl) == 1) {
        ++count;
      } else if ((pr - pl) > 1) {
        if (pivot == -1) {
          // All these keys end at `depth` and share the prefix: one key.
          ++count;
        } else {
          count += sort(pl, pr, depth + 1);
        }
      }

      if ((pl - l) < (r - pr)) {
        if ((pl - l) == 1) {
          ++count;
        } else if ((pl - l) > 1) {
          count += sort(l, pl, depth);
        }
        l = pr;
      } else {
        if ((r - pr) == 1) {
          ++count;
        } else if ((r - pr) > 1) {
          count += sort(pr, r, depth);
        }
        r = pl;
      }
    } else {
      // The "== pivot" part is at least as large as both others: handle both
      // sides now, then loop on the middle part one byte deeper.
      if ((pl - l) == 1) {
        ++count;
      } else if ((pl - l) > 1) {
        count += sort(l, pl, depth);
      }

      if ((r - pr) == 1) {
        ++count;
      } else if ((r - pr) > 1) {
        count += sort(pr, r, depth);
      }

      l = pl, r = pr;
      if ((pr - pl) == 1) {
        // Counted here; the loop and the final insertion sort are skipped.
        ++count;
      } else if ((pr - pl) > 1) {
        if (pivot == -1) {
          // All keys are identical: count once and empty the range to stop.
          l = r;
          ++count;
        } else {
          ++depth;
        }
      }
    }
  }

  if ((r - l) > 1) {
    count += insertion_sort(l, r, depth);
  }
  return count;
}
183
184 } // namespace details
185
186 template <typename Iterator>
sort(Iterator begin,Iterator end)187 std::size_t sort(Iterator begin, Iterator end) {
188 MARISA_DEBUG_IF(begin > end, MARISA_BOUND_ERROR);
189 return details::sort(begin, end, 0);
190 };
191
192 } // namespace algorithm
193 } // namespace grimoire
194 } // namespace marisa
195
196 #endif // MARISA_GRIMOIRE_ALGORITHM_SORT_H_
197