1 /* 2 * Copyright (c) 2025 Huawei Device Co., Ltd. 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15 #ifndef RUNTIME_INCLUDE_TAIHE_SET_HPP_ 16 #define RUNTIME_INCLUDE_TAIHE_SET_HPP_ 17 // NOLINTBEGIN 18 19 #include <taihe/common.hpp> 20 21 #include <utility> 22 23 #define SET_GROWTH_FACTOR 2 24 25 namespace taihe { 26 template <typename K> 27 struct set_view; 28 29 template <typename K> 30 struct set; 31 32 template <typename K> 33 struct set_view { 34 public: 35 using item_t = K const; 36 reservetaihe::set_view37 void reserve(std::size_t cap) const 38 { 39 if (cap == 0) { 40 return; 41 } 42 node_t **bucket = reinterpret_cast<node_t **>(calloc(cap, sizeof(node_t *))); 43 for (std::size_t i = 0; i < m_handle->cap; i++) { 44 node_t *current = m_handle->bucket[i]; 45 while (current) { 46 node_t *next = current->next; 47 std::size_t index = std::hash<K>()(current->item) % cap; 48 current->next = bucket[index]; 49 bucket[index] = current; 50 current = next; 51 } 52 } 53 free(m_handle->bucket); 54 m_handle->cap = cap; 55 m_handle->bucket = bucket; 56 } 57 sizetaihe::set_view58 std::size_t size() const noexcept 59 { 60 return m_handle->size; 61 } 62 emptytaihe::set_view63 bool empty() const noexcept 64 { 65 return m_handle->size == 0; 66 } 67 capacitytaihe::set_view68 std::size_t capacity() const noexcept 69 { 70 return m_handle->cap; 71 } 72 cleartaihe::set_view73 void clear() const 74 { 75 for (std::size_t i = 0; i < m_handle->cap; i++) { 76 while (m_handle->bucket[i]) { 77 node_t *next = m_handle->bucket[i]->next; 78 delete m_handle->bucket[i]; 79 m_handle->bucket[i] = next; 80 } 81 } 82 m_handle->size = 0; 83 } 84 85 template <bool cover = false> emplacetaihe::set_view86 std::pair<item_t *, bool> emplace(as_param_t<K> key) const 87 { 88 std::size_t index = std::hash<K>()(key) % m_handle->cap; 89 node_t **current_ptr = &m_handle->bucket[index]; 90 while (*current_ptr) { 91 if ((*current_ptr)->item == key) { 92 if (cover) { 93 node_t *replaced = new node_t { 94 .item = key, 95 .next = (*current_ptr)->next, 96 }; 97 node_t *current = *current_ptr; 98 *current_ptr = replaced; 99 delete current; 100 } 101 return {&(*current_ptr)->item, false}; 102 } 103 current_ptr = &(*current_ptr)->next; 104 } 105 node_t *node = new node_t { 106 .item = key, 107 .next = m_handle->bucket[index], 108 }; 109 m_handle->bucket[index] = node; 110 m_handle->size++; 111 std::size_t required_cap = m_handle->size; 112 if (required_cap >= m_handle->cap) { 113 reserve(required_cap * SET_GROWTH_FACTOR); 114 } 115 return {&node->item, true}; 116 } 117 find_itemtaihe::set_view118 item_t *find_item(as_param_t<K> key) const 119 { 120 std::size_t index = std::hash<K>()(key) % m_handle->cap; 121 node_t *current = m_handle->bucket[index]; 122 while (current) { 123 if (current->item == key) { 124 return ¤t->item; 125 } 126 current = current->next; 127 } 128 return nullptr; 129 } 130 131 // TODO: Change the return type to item_t * findtaihe::set_view132 bool find(as_param_t<K> key) const 133 { 134 item_t *item = find_item(key); 135 if (item) { 136 return true; 137 } 138 return false; 139 } 140 erasetaihe::set_view141 bool erase(as_param_t<K> key) const 142 { 143 std::size_t index = std::hash<K>()(key) % m_handle->cap; 144 node_t **current_ptr = &m_handle->bucket[index]; 145 while (*current_ptr) { 146 if ((*current_ptr)->item == key) { 147 node_t *current = *current_ptr; 148 *current_ptr = (*current_ptr)->next; 149 delete current; 150 m_handle->size--; 151 return true; 152 } 153 current_ptr = &(*current_ptr)->next; 154 } 155 return false; 156 } 157 158 struct node_t { 159 item_t item; 160 node_t *next; 161 }; 162 163 struct iterator { 164 using iterator_category = std::forward_iterator_tag; 165 using value_type = item_t; 166 using difference_type = std::ptrdiff_t; 167 using pointer = value_type *; 168 using reference = value_type &; 169 iteratortaihe::set_view::iterator170 iterator(node_t **bucket, node_t *current, std::size_t index, std::size_t cap) 171 : bucket(bucket), current(current), index(index), cap(cap) 172 { 173 } 174 operator *taihe::set_view::iterator175 reference operator*() const 176 { 177 return current->item; 178 } 179 operator ->taihe::set_view::iterator180 pointer operator->() const 181 { 182 return ¤t->item; 183 } 184 operator ++taihe::set_view::iterator185 iterator &operator++() 186 { 187 if (current->next) { 188 current = current->next; 189 } else { 190 ++index; 191 while (index < cap && !bucket[index]) { 192 ++index; 193 } 194 current = (index < cap) ? bucket[index] : nullptr; 195 } 196 return *this; 197 } 198 operator ++taihe::set_view::iterator199 iterator operator++(int) 200 { 201 iterator ocp = *this; 202 ++(*this); 203 return ocp; 204 } 205 operator ==taihe::set_view::iterator206 bool operator==(iterator const &other) const 207 { 208 return current == other.current; 209 } 210 operator !=taihe::set_view::iterator211 bool operator!=(iterator const &other) const 212 { 213 return !(*this == other); 214 } 215 216 private: 217 node_t **bucket; 218 node_t *current; 219 std::size_t index; 220 std::size_t cap; 221 }; 222 begintaihe::set_view223 iterator begin() const 224 { 225 std::size_t index = 0; 226 while (index < m_handle->cap && !m_handle->bucket[index]) { 227 ++index; 228 } 229 return iterator(m_handle->bucket, (index < m_handle->cap) ? m_handle->bucket[index] : nullptr, index, 230 m_handle->cap); 231 } 232 endtaihe::set_view233 iterator end() const 234 { 235 return iterator(m_handle->bucket, nullptr, m_handle->cap, m_handle->cap); 236 } 237 238 using const_iterator = iterator; 239 cbegintaihe::set_view240 const_iterator cbegin() const 241 { 242 return begin(); 243 } 244 cendtaihe::set_view245 const_iterator cend() const 246 { 247 return end(); 248 } 249 250 template <typename Visitor> accepttaihe::set_view251 void accept(Visitor &&visitor) const 252 { 253 for (std::size_t i = 0; i < m_handle->cap; i++) { 254 node_t *current = m_handle->bucket[i]; 255 while (current) { 256 visitor(current->item); 257 current = current->next; 258 } 259 } 260 } 261 262 private: 263 struct handle_t { 264 TRefCount count; 265 std::size_t cap; 266 node_t **bucket; 267 std::size_t size; 268 } *m_handle; 269 set_viewtaihe::set_view270 explicit set_view(handle_t *handle) : m_handle(handle) {} 271 272 friend struct set<K>; 273 274 friend struct std::hash<set<K>>; 275 operator ==(set_view lhs,set_view rhs)276 friend bool operator==(set_view lhs, set_view rhs) 277 { 278 return lhs.m_handle == rhs.m_handle; 279 } 280 }; 281 282 template <typename K> 283 struct set : set_view<K> { 284 using typename set_view<K>::node_t; 285 using typename set_view<K>::handle_t; 286 using set_view<K>::m_handle; 287 settaihe::set288 explicit set(std::size_t cap = 16) : set(reinterpret_cast<handle_t *>(calloc(1, sizeof(handle_t)))) 289 { 290 node_t **bucket = reinterpret_cast<node_t **>(calloc(cap, sizeof(node_t *))); 291 tref_set(&m_handle->count, 1); 292 m_handle->cap = cap; 293 m_handle->bucket = bucket; 294 m_handle->size = 0; 295 } 296 settaihe::set297 set(set<K> &&other) noexcept : set(other.m_handle) 298 { 299 other.m_handle = nullptr; 300 } 301 settaihe::set302 set(set<K> const &other) : set(other.m_handle) 303 { 304 if (m_handle) { 305 tref_inc(&m_handle->count); 306 } 307 } 308 settaihe::set309 set(set_view<K> const &other) : set(other.m_handle) 310 { 311 if (m_handle) { 312 tref_inc(&m_handle->count); 313 } 314 } 315 operator =taihe::set316 set &operator=(set other) 317 { 318 std::swap(this->m_handle, other.m_handle); 319 return *this; 320 } 321 ~settaihe::set322 ~set() 323 { 324 if (m_handle && tref_dec(&m_handle->count)) { 325 this->clear(); 326 free(m_handle->bucket); 327 free(m_handle); 328 } 329 } 330 331 private: settaihe::set332 explicit set(handle_t *handle) : set_view<K>(handle) {} 333 }; 334 335 template <typename K> 336 struct as_abi<set<K>> { 337 using type = void *; 338 }; 339 340 template <typename K> 341 struct as_abi<set_view<K>> { 342 using type = void *; 343 }; 344 345 template <typename K> 346 struct as_param<set<K>> { 347 using type = set_view<K>; 348 }; 349 } // namespace taihe 350 351 template <typename K> 352 struct std::hash<taihe::set<K>> { operator ()std::hash353 std::size_t operator()(taihe::set_view<K> val) const noexcept 354 { 355 return reinterpret_cast<std::size_t>(val.m_handle); 356 } 357 }; 358 359 #ifdef SET_GROWTH_FACTOR 360 #undef SET_GROWTH_FACTOR 361 #endif 362 // NOLINTEND 363 #endif // RUNTIME_INCLUDE_TAIHE_SET_HPP_