• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2025 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 #ifndef RUNTIME_INCLUDE_TAIHE_SET_HPP_
16 #define RUNTIME_INCLUDE_TAIHE_SET_HPP_
17 // NOLINTBEGIN
18 
#include <taihe/common.hpp>

#include <cstddef>
#include <cstdlib>
#include <functional>
#include <iterator>
#include <utility>
22 
23 #define SET_GROWTH_FACTOR 2
24 
25 namespace taihe {
26 template <typename K>
27 struct set_view;
28 
29 template <typename K>
30 struct set;
31 
32 template <typename K>
33 struct set_view {
34 public:
35     using item_t = K const;
36 
reservetaihe::set_view37     void reserve(std::size_t cap) const
38     {
39         if (cap == 0) {
40             return;
41         }
42         node_t **bucket = reinterpret_cast<node_t **>(calloc(cap, sizeof(node_t *)));
43         for (std::size_t i = 0; i < m_handle->cap; i++) {
44             node_t *current = m_handle->bucket[i];
45             while (current) {
46                 node_t *next = current->next;
47                 std::size_t index = std::hash<K>()(current->item) % cap;
48                 current->next = bucket[index];
49                 bucket[index] = current;
50                 current = next;
51             }
52         }
53         free(m_handle->bucket);
54         m_handle->cap = cap;
55         m_handle->bucket = bucket;
56     }
57 
sizetaihe::set_view58     std::size_t size() const noexcept
59     {
60         return m_handle->size;
61     }
62 
emptytaihe::set_view63     bool empty() const noexcept
64     {
65         return m_handle->size == 0;
66     }
67 
capacitytaihe::set_view68     std::size_t capacity() const noexcept
69     {
70         return m_handle->cap;
71     }
72 
cleartaihe::set_view73     void clear() const
74     {
75         for (std::size_t i = 0; i < m_handle->cap; i++) {
76             while (m_handle->bucket[i]) {
77                 node_t *next = m_handle->bucket[i]->next;
78                 delete m_handle->bucket[i];
79                 m_handle->bucket[i] = next;
80             }
81         }
82         m_handle->size = 0;
83     }
84 
85     template <bool cover = false>
emplacetaihe::set_view86     std::pair<item_t *, bool> emplace(as_param_t<K> key) const
87     {
88         std::size_t index = std::hash<K>()(key) % m_handle->cap;
89         node_t **current_ptr = &m_handle->bucket[index];
90         while (*current_ptr) {
91             if ((*current_ptr)->item == key) {
92                 if (cover) {
93                     node_t *replaced = new node_t {
94                         .item = key,
95                         .next = (*current_ptr)->next,
96                     };
97                     node_t *current = *current_ptr;
98                     *current_ptr = replaced;
99                     delete current;
100                 }
101                 return {&(*current_ptr)->item, false};
102             }
103             current_ptr = &(*current_ptr)->next;
104         }
105         node_t *node = new node_t {
106             .item = key,
107             .next = m_handle->bucket[index],
108         };
109         m_handle->bucket[index] = node;
110         m_handle->size++;
111         std::size_t required_cap = m_handle->size;
112         if (required_cap >= m_handle->cap) {
113             reserve(required_cap * SET_GROWTH_FACTOR);
114         }
115         return {&node->item, true};
116     }
117 
find_itemtaihe::set_view118     item_t *find_item(as_param_t<K> key) const
119     {
120         std::size_t index = std::hash<K>()(key) % m_handle->cap;
121         node_t *current = m_handle->bucket[index];
122         while (current) {
123             if (current->item == key) {
124                 return &current->item;
125             }
126             current = current->next;
127         }
128         return nullptr;
129     }
130 
131     // TODO: Change the return type to item_t *
findtaihe::set_view132     bool find(as_param_t<K> key) const
133     {
134         item_t *item = find_item(key);
135         if (item) {
136             return true;
137         }
138         return false;
139     }
140 
erasetaihe::set_view141     bool erase(as_param_t<K> key) const
142     {
143         std::size_t index = std::hash<K>()(key) % m_handle->cap;
144         node_t **current_ptr = &m_handle->bucket[index];
145         while (*current_ptr) {
146             if ((*current_ptr)->item == key) {
147                 node_t *current = *current_ptr;
148                 *current_ptr = (*current_ptr)->next;
149                 delete current;
150                 m_handle->size--;
151                 return true;
152             }
153             current_ptr = &(*current_ptr)->next;
154         }
155         return false;
156     }
157 
158     struct node_t {
159         item_t item;
160         node_t *next;
161     };
162 
163     struct iterator {
164         using iterator_category = std::forward_iterator_tag;
165         using value_type = item_t;
166         using difference_type = std::ptrdiff_t;
167         using pointer = value_type *;
168         using reference = value_type &;
169 
iteratortaihe::set_view::iterator170         iterator(node_t **bucket, node_t *current, std::size_t index, std::size_t cap)
171             : bucket(bucket), current(current), index(index), cap(cap)
172         {
173         }
174 
operator *taihe::set_view::iterator175         reference operator*() const
176         {
177             return current->item;
178         }
179 
operator ->taihe::set_view::iterator180         pointer operator->() const
181         {
182             return &current->item;
183         }
184 
operator ++taihe::set_view::iterator185         iterator &operator++()
186         {
187             if (current->next) {
188                 current = current->next;
189             } else {
190                 ++index;
191                 while (index < cap && !bucket[index]) {
192                     ++index;
193                 }
194                 current = (index < cap) ? bucket[index] : nullptr;
195             }
196             return *this;
197         }
198 
operator ++taihe::set_view::iterator199         iterator operator++(int)
200         {
201             iterator ocp = *this;
202             ++(*this);
203             return ocp;
204         }
205 
operator ==taihe::set_view::iterator206         bool operator==(iterator const &other) const
207         {
208             return current == other.current;
209         }
210 
operator !=taihe::set_view::iterator211         bool operator!=(iterator const &other) const
212         {
213             return !(*this == other);
214         }
215 
216     private:
217         node_t **bucket;
218         node_t *current;
219         std::size_t index;
220         std::size_t cap;
221     };
222 
begintaihe::set_view223     iterator begin() const
224     {
225         std::size_t index = 0;
226         while (index < m_handle->cap && !m_handle->bucket[index]) {
227             ++index;
228         }
229         return iterator(m_handle->bucket, (index < m_handle->cap) ? m_handle->bucket[index] : nullptr, index,
230                         m_handle->cap);
231     }
232 
endtaihe::set_view233     iterator end() const
234     {
235         return iterator(m_handle->bucket, nullptr, m_handle->cap, m_handle->cap);
236     }
237 
238     using const_iterator = iterator;
239 
cbegintaihe::set_view240     const_iterator cbegin() const
241     {
242         return begin();
243     }
244 
cendtaihe::set_view245     const_iterator cend() const
246     {
247         return end();
248     }
249 
250     template <typename Visitor>
accepttaihe::set_view251     void accept(Visitor &&visitor) const
252     {
253         for (std::size_t i = 0; i < m_handle->cap; i++) {
254             node_t *current = m_handle->bucket[i];
255             while (current) {
256                 visitor(current->item);
257                 current = current->next;
258             }
259         }
260     }
261 
262 private:
263     struct handle_t {
264         TRefCount count;
265         std::size_t cap;
266         node_t **bucket;
267         std::size_t size;
268     } *m_handle;
269 
set_viewtaihe::set_view270     explicit set_view(handle_t *handle) : m_handle(handle) {}
271 
272     friend struct set<K>;
273 
274     friend struct std::hash<set<K>>;
275 
operator ==(set_view lhs,set_view rhs)276     friend bool operator==(set_view lhs, set_view rhs)
277     {
278         return lhs.m_handle == rhs.m_handle;
279     }
280 };
281 
282 template <typename K>
283 struct set : set_view<K> {
284     using typename set_view<K>::node_t;
285     using typename set_view<K>::handle_t;
286     using set_view<K>::m_handle;
287 
settaihe::set288     explicit set(std::size_t cap = 16) : set(reinterpret_cast<handle_t *>(calloc(1, sizeof(handle_t))))
289     {
290         node_t **bucket = reinterpret_cast<node_t **>(calloc(cap, sizeof(node_t *)));
291         tref_set(&m_handle->count, 1);
292         m_handle->cap = cap;
293         m_handle->bucket = bucket;
294         m_handle->size = 0;
295     }
296 
settaihe::set297     set(set<K> &&other) noexcept : set(other.m_handle)
298     {
299         other.m_handle = nullptr;
300     }
301 
settaihe::set302     set(set<K> const &other) : set(other.m_handle)
303     {
304         if (m_handle) {
305             tref_inc(&m_handle->count);
306         }
307     }
308 
settaihe::set309     set(set_view<K> const &other) : set(other.m_handle)
310     {
311         if (m_handle) {
312             tref_inc(&m_handle->count);
313         }
314     }
315 
operator =taihe::set316     set &operator=(set other)
317     {
318         std::swap(this->m_handle, other.m_handle);
319         return *this;
320     }
321 
~settaihe::set322     ~set()
323     {
324         if (m_handle && tref_dec(&m_handle->count)) {
325             this->clear();
326             free(m_handle->bucket);
327             free(m_handle);
328         }
329     }
330 
331 private:
settaihe::set332     explicit set(handle_t *handle) : set_view<K>(handle) {}
333 };
334 
335 template <typename K>
336 struct as_abi<set<K>> {
337     using type = void *;
338 };
339 
340 template <typename K>
341 struct as_abi<set_view<K>> {
342     using type = void *;
343 };
344 
345 template <typename K>
346 struct as_param<set<K>> {
347     using type = set_view<K>;
348 };
349 }  // namespace taihe
350 
351 template <typename K>
352 struct std::hash<taihe::set<K>> {
operator ()std::hash353     std::size_t operator()(taihe::set_view<K> val) const noexcept
354     {
355         return reinterpret_cast<std::size_t>(val.m_handle);
356     }
357 };
358 
// Keep the growth-factor macro private to this header. #undef of a name that
// is not defined is a no-op, so no #ifdef guard is needed.
#undef SET_GROWTH_FACTOR
362 // NOLINTEND
363 #endif  // RUNTIME_INCLUDE_TAIHE_SET_HPP_