1 /* 2 * Copyright (c) 2025 Huawei Device Co., Ltd. 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15 #ifndef RUNTIME_INCLUDE_TAIHE_MAP_HPP_ 16 #define RUNTIME_INCLUDE_TAIHE_MAP_HPP_ 17 // NOLINTBEGIN 18 19 #include <taihe/common.hpp> 20 21 #include <utility> 22 23 #define MAP_GROWTH_FACTOR 2 24 25 namespace taihe { 26 template <typename K, typename V> 27 struct map_view; 28 29 template <typename K, typename V> 30 struct map; 31 32 template <typename K, typename V> 33 struct map_view { 34 public: 35 using item_t = std::pair<K const, V>; 36 reservetaihe::map_view37 void reserve(std::size_t cap) const 38 { 39 if (cap == 0) { 40 return; 41 } 42 node_t **bucket = reinterpret_cast<node_t **>(calloc(cap, sizeof(node_t *))); 43 for (std::size_t i = 0; i < m_handle->cap; i++) { 44 node_t *current = m_handle->bucket[i]; 45 while (current) { 46 node_t *next = current->next; 47 std::size_t index = std::hash<K>()(current->item.first) % cap; 48 current->next = bucket[index]; 49 bucket[index] = current; 50 current = next; 51 } 52 } 53 free(m_handle->bucket); 54 m_handle->cap = cap; 55 m_handle->bucket = bucket; 56 } 57 sizetaihe::map_view58 std::size_t size() const noexcept 59 { 60 return m_handle->size; 61 } 62 emptytaihe::map_view63 bool empty() const noexcept 64 { 65 return m_handle->size == 0; 66 } 67 capacitytaihe::map_view68 std::size_t capacity() const noexcept 69 { 70 return m_handle->cap; 71 } 72 cleartaihe::map_view73 void clear() const 74 { 75 for (std::size_t i = 0; i < m_handle->cap; i++) { 76 while (m_handle->bucket[i]) { 77 node_t *next = m_handle->bucket[i]->next; 78 delete m_handle->bucket[i]; 79 m_handle->bucket[i] = next; 80 } 81 } 82 m_handle->size = 0; 83 } 84 85 template <bool cover = false, typename... Args> emplacetaihe::map_view86 std::pair<item_t *, bool> emplace(as_param_t<K> key, Args &&...args) const 87 { 88 std::size_t index = std::hash<K>()(key) % m_handle->cap; 89 node_t **current_ptr = &m_handle->bucket[index]; 90 while (*current_ptr) { 91 if ((*current_ptr)->item.first == key) { 92 if (cover) { 93 node_t *replaced = new node_t { 94 .item = {key, V {std::forward<Args>(args)...}}, 95 .next = (*current_ptr)->next, 96 }; 97 node_t *current = *current_ptr; 98 *current_ptr = replaced; 99 delete current; 100 } 101 return {&(*current_ptr)->item, false}; 102 } 103 current_ptr = &(*current_ptr)->next; 104 } 105 node_t *node = new node_t { 106 .item = {key, V {std::forward<Args>(args)...}}, 107 .next = m_handle->bucket[index], 108 }; 109 m_handle->bucket[index] = node; 110 m_handle->size++; 111 std::size_t required_cap = m_handle->size; 112 if (required_cap >= m_handle->cap) { 113 reserve(required_cap * MAP_GROWTH_FACTOR); 114 } 115 return {&node->item, true}; 116 } 117 find_itemtaihe::map_view118 item_t *find_item(as_param_t<K> key) const 119 { 120 std::size_t index = std::hash<K>()(key) % m_handle->cap; 121 node_t *current = m_handle->bucket[index]; 122 while (current) { 123 if (current->item.first == key) { 124 return ¤t->item; 125 } 126 current = current->next; 127 } 128 return nullptr; 129 } 130 131 // TODO: Change the return type to item_t * findtaihe::map_view132 V *find(as_param_t<K> key) const 133 { 134 item_t *item = find_item(key); 135 if (item) { 136 return &item->second; 137 } 138 return nullptr; 139 } 140 erasetaihe::map_view141 bool erase(as_param_t<K> key) const 142 { 143 std::size_t index = std::hash<K>()(key) % m_handle->cap; 144 node_t **current_ptr = &m_handle->bucket[index]; 145 while (*current_ptr) { 146 if ((*current_ptr)->item.first == key) { 147 node_t *current = *current_ptr; 148 *current_ptr = (*current_ptr)->next; 149 delete current; 150 m_handle->size--; 151 return true; 152 } 153 current_ptr = &(*current_ptr)->next; 154 } 155 return false; 156 } 157 158 struct node_t { 159 item_t item; 160 node_t *next; 161 }; 162 163 struct iterator { 164 using iterator_category = std::forward_iterator_tag; 165 using value_type = item_t; 166 using difference_type = std::ptrdiff_t; 167 using pointer = value_type *; 168 using reference = value_type &; 169 iteratortaihe::map_view::iterator170 iterator(node_t **bucket, node_t *current, std::size_t index, std::size_t cap) 171 : bucket(bucket), current(current), index(index), cap(cap) 172 { 173 } 174 operator *taihe::map_view::iterator175 reference operator*() const 176 { 177 return current->item; 178 } 179 operator ->taihe::map_view::iterator180 pointer operator->() const 181 { 182 return ¤t->item; 183 } 184 operator ++taihe::map_view::iterator185 iterator &operator++() 186 { 187 if (current->next) { 188 current = current->next; 189 } else { 190 ++index; 191 while (index < cap && !bucket[index]) { 192 ++index; 193 } 194 current = (index < cap) ? bucket[index] : nullptr; 195 } 196 return *this; 197 } 198 operator ++taihe::map_view::iterator199 iterator operator++(int) 200 { 201 iterator ocp = *this; 202 ++(*this); 203 return ocp; 204 } 205 operator ==taihe::map_view::iterator206 bool operator==(iterator const &other) const 207 { 208 return current == other.current; 209 } 210 operator !=taihe::map_view::iterator211 bool operator!=(iterator const &other) const 212 { 213 return !(*this == other); 214 } 215 216 private: 217 node_t **bucket; 218 node_t *current; 219 std::size_t index; 220 std::size_t cap; 221 }; 222 begintaihe::map_view223 iterator begin() const 224 { 225 std::size_t index = 0; 226 while (index < m_handle->cap && !m_handle->bucket[index]) { 227 ++index; 228 } 229 return iterator(m_handle->bucket, (index < m_handle->cap) ? m_handle->bucket[index] : nullptr, index, 230 m_handle->cap); 231 } 232 endtaihe::map_view233 iterator end() const 234 { 235 return iterator(m_handle->bucket, nullptr, m_handle->cap, m_handle->cap); 236 } 237 238 using const_iterator = iterator; 239 cbegintaihe::map_view240 const_iterator cbegin() const 241 { 242 return begin(); 243 } 244 cendtaihe::map_view245 const_iterator cend() const 246 { 247 return end(); 248 } 249 250 template <typename Visitor> accepttaihe::map_view251 void accept(Visitor &&visitor) const 252 { 253 for (std::size_t i = 0; i < m_handle->cap; i++) { 254 node_t *current = m_handle->bucket[i]; 255 while (current) { 256 visitor(current->item); 257 current = current->next; 258 } 259 } 260 } 261 262 private: 263 struct handle_t { 264 TRefCount count; 265 std::size_t cap; 266 node_t **bucket; 267 std::size_t size; 268 } *m_handle; 269 map_viewtaihe::map_view270 explicit map_view(handle_t *handle) : m_handle(handle) {} 271 272 friend struct map<K, V>; 273 274 friend struct std::hash<map<K, V>>; 275 operator ==(map_view lhs,map_view rhs)276 friend bool operator==(map_view lhs, map_view rhs) 277 { 278 return lhs.m_handle == rhs.m_handle; 279 } 280 }; 281 282 template <typename K, typename V> 283 struct map : map_view<K, V> { 284 using typename map_view<K, V>::node_t; 285 using typename map_view<K, V>::handle_t; 286 using map_view<K, V>::m_handle; 287 maptaihe::map288 explicit map(std::size_t cap = 16) : map(reinterpret_cast<handle_t *>(calloc(1, sizeof(handle_t)))) 289 { 290 node_t **bucket = reinterpret_cast<node_t **>(calloc(cap, sizeof(node_t *))); 291 tref_set(&m_handle->count, 1); 292 m_handle->cap = cap; 293 m_handle->bucket = bucket; 294 m_handle->size = 0; 295 } 296 maptaihe::map297 map(map<K, V> &&other) noexcept : map(other.m_handle) 298 { 299 other.m_handle = nullptr; 300 } 301 maptaihe::map302 map(map<K, V> const &other) : map(other.m_handle) 303 { 304 if (m_handle) { 305 tref_inc(&m_handle->count); 306 } 307 } 308 maptaihe::map309 map(map_view<K, V> const &other) : map(other.m_handle) 310 { 311 if (m_handle) { 312 tref_inc(&m_handle->count); 313 } 314 } 315 operator =taihe::map316 map &operator=(map other) 317 { 318 std::swap(this->m_handle, other.m_handle); 319 return *this; 320 } 321 ~maptaihe::map322 ~map() 323 { 324 if (m_handle && tref_dec(&m_handle->count)) { 325 this->clear(); 326 free(m_handle->bucket); 327 free(m_handle); 328 } 329 } 330 331 private: maptaihe::map332 explicit map(handle_t *handle) : map_view<K, V>(handle) {} 333 }; 334 335 template <typename K, typename V> 336 struct as_abi<map<K, V>> { 337 using type = void *; 338 }; 339 340 template <typename K, typename V> 341 struct as_abi<map_view<K, V>> { 342 using type = void *; 343 }; 344 345 template <typename K, typename V> 346 struct as_param<map<K, V>> { 347 using type = map_view<K, V>; 348 }; 349 } // namespace taihe 350 351 template <typename K, typename V> 352 struct std::hash<taihe::map<K, V>> { operator ()std::hash353 std::size_t operator()(taihe::map_view<K, V> val) const noexcept 354 { 355 return reinterpret_cast<std::size_t>(val.m_handle); 356 } 357 }; 358 359 #ifdef MAP_GROWTH_FACTOR 360 #undef MAP_GROWTH_FACTOR 361 #endif 362 // NOLINTEND 363 #endif // RUNTIME_INCLUDE_TAIHE_MAP_HPP_