/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/util/circular_pool.h"

#include <algorithm>
#include <atomic>
#include <limits>

#include "securec.h"
#include "minddata/dataset/util/log_adapter.h"
#include "minddata/dataset/util/system_pool.h"

namespace mindspore {
namespace dataset {
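
// CircularPool, as implemented in this file, owns a list of fixed-size
// Arenas (mem_segments_) and a tail_ pointer marking where the circular
// scan starts. Allocate() walks the arenas beginning at the tail and
// advances the tail past arenas that are full; when a full pass finds no
// space and the size cap allows, AddOneArena() appends a new arena, which
// becomes the new tail.
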
Status CircularPool::AddOneArena() {
  std::shared_ptr<Arena> b;
  RETURN_IF_NOT_OK(Arena::CreateArena(&b, arena_size_, is_cuda_malloc_));
  tail_ = b.get();
  cur_size_in_mb_ += arena_size_;
  mem_segments_.push_back(std::move(b));
  return Status::OK();
}

ListOfArenas::iterator CircularPool::CircularIterator::Next() {
  ListOfArenas::iterator it = dp_->mem_segments_.begin();
  uint32_t size = dp_->mem_segments_.size();
  // This is the arena we return on this call.
  it += cur_;
  // Prepare for the next round.
  cur_++;
  if (cur_ == size) {
    if (start_ == 0) {
      has_next_ = false;
    } else {
      wrap_ = true;
      cur_ = 0;
    }
  } else if (cur_ == start_) {
    has_next_ = false;
  }
  return it;
}
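
// Illustrative walk-through of the iteration order: with four arenas and
// start_ == 2, successive calls to Next() visit indices 2, 3, 0, 1, and
// has_next_ is cleared once the scan returns to start_. When start_ == 0
// there is nothing to wrap around to, so the scan simply runs 0..size-1.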

bool CircularPool::CircularIterator::has_next() const { return has_next_; }

void CircularPool::CircularIterator::Reset() {
  wrap_ = false;
  has_next_ = false;
  if (!dp_->mem_segments_.empty()) {
    // Find the buddy arena that corresponds to the tail.
    cur_tail_ = dp_->tail_;
    auto list_end = dp_->mem_segments_.end();
    auto it = std::find_if(dp_->mem_segments_.begin(), list_end,
                           [this](const std::shared_ptr<Arena> &b) { return b.get() == cur_tail_; });
    MS_ASSERT(it != list_end);
    start_ = std::distance(dp_->mem_segments_.begin(), it);
    cur_ = start_;
    has_next_ = true;
  }
}

CircularPool::CircularIterator::CircularIterator(CircularPool *dp) : dp_(dp) { Reset(); }

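// Allocation strategy, as implemented below: take the chain in shared mode
// and scan the arenas circularly starting at the tail, advancing the tail
// past any arena that reports out-of-memory. If a whole pass finds no space
// and the cap allows, upgrade to the exclusive lock, re-check that another
// thread has not already expanded the pool (cur_size_in_mb_ still equals the
// value read earlier), add one arena, downgrade, and retry.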
Status CircularPool::Allocate(size_t n, void **p) {
  if (p == nullptr) {
    RETURN_STATUS_UNEXPECTED("p is null");
  }
  Status rc;
  void *ptr = nullptr;
  do {
    SharedLock lock_s(&rw_lock_);
    int prevSzInMB = cur_size_in_mb_;
    bool move_tail = false;
    CircularIterator cirIt(this);
    while (cirIt.has_next()) {
      auto it = cirIt.Next();
      Arena *ba = it->get();
      RETURN_UNEXPECTED_IF_NULL(ba);
      if (ba->get_max_size() < n) {
        RETURN_STATUS_OOM("Out of memory.");
      }
      // If we are asked to move the tail forward, do it before allocating.
      if (move_tail) {
        Arena *expected = cirIt.cur_tail_;
        (void)atomic_compare_exchange_weak(&tail_, &expected, ba);
        move_tail = false;
      }
      rc = ba->Allocate(n, &ptr);
      if (rc.IsOk()) {
        *p = ptr;
        break;
      } else if (rc == StatusCode::kMDOutOfMemory) {
        // Make the next arena the new tail and continue.
        move_tail = true;
      } else {
        return rc;
      }
    }

    // Handle the case where we have completed one full round-robin pass.
    if (ptr == nullptr) {
      // If we have room to expand.
      if (unlimited_ || cur_size_in_mb_ < max_size_in_mb_) {
        // Lock in exclusive mode.
        lock_s.Upgrade();
        // Check again in case someone else has already expanded.
        if (cur_size_in_mb_ == prevSzInMB) {
          RETURN_IF_NOT_OK(AddOneArena());
        }
        // Re-acquire the shared lock and try again.
        lock_s.Downgrade();
      } else {
        RETURN_STATUS_OOM("Out of memory.");
      }
    }
  } while (ptr == nullptr);
  return rc;
}

void CircularPool::Deallocate(void *p) {
  // Lock the chain in shared mode and find out which segment the
  // pointer comes from.
  SharedLock lock(&rw_lock_);
  auto it = std::find_if(mem_segments_.begin(), mem_segments_.end(), [this, p](std::shared_ptr<Arena> &b) -> bool {
    char *q = reinterpret_cast<char *>(p);
    auto *base = reinterpret_cast<const char *>(b->get_base_addr());
    // arena_size_ is in MB; 1048576 converts MB to bytes.
    return (q > base && q < base + arena_size_ * 1048576L);
  });
  lock.Unlock();
  MS_ASSERT(it != mem_segments_.end());
  it->get()->Deallocate(p);
}

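// Reallocate() first asks the owning arena to resize the block in place. If
// that arena has no room, the fallback path below allocates a fresh block
// from the pool (possibly in a different arena), copies the old contents
// over with memcpy_s, and releases the original block.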
Status CircularPool::Reallocate(void **pp, size_t old_sz, size_t new_sz) {
  // Lock the chain in shared mode and find out which segment the
  // pointer comes from.
  if (pp == nullptr) {
    RETURN_STATUS_UNEXPECTED("pp is null");
  }
  void *p = *pp;
  SharedLock lock(&rw_lock_);
  auto it = std::find_if(mem_segments_.begin(), mem_segments_.end(), [this, p](std::shared_ptr<Arena> &b) -> bool {
    char *q = reinterpret_cast<char *>(p);
    auto *base = reinterpret_cast<const char *>(b->get_base_addr());
    return (q > base && q < base + arena_size_ * 1048576L);
  });
  lock.Unlock();
  MS_ASSERT(it != mem_segments_.end());
  Arena *ba = it->get();
  Status rc = ba->Reallocate(pp, old_sz, new_sz);
  if (rc == StatusCode::kMDOutOfMemory) {
    // The current arena has no room for the bigger size.
    // Allocate free space from another arena and copy
    // the content over.
    void *q = nullptr;
    rc = this->Allocate(new_sz, &q);
    RETURN_IF_NOT_OK(rc);
    errno_t err = memcpy_s(q, new_sz, p, old_sz);
    if (err != EOK) {
      this->Deallocate(q);
      RETURN_STATUS_UNEXPECTED("memcpy_s failed with error " + std::to_string(err));
    }
    *pp = q;
    ba->Deallocate(p);
  }
  return Status::OK();
}

uint64_t CircularPool::get_max_size() const { return mem_segments_.front()->get_max_size(); }

int CircularPool::PercentFree() const {
  int percent_free = 0;
  int num_arena = 0;
  for (auto const &p : mem_segments_) {
    percent_free += p->PercentFree();
    num_arena++;
  }
  if (num_arena) {
    return percent_free / num_arena;
  } else {
    return 100;
  }
}

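// Note on units: the constructor takes its cap in GB but tracks sizes in MB
// internally (max_size_in_gb * 1024), and arena_size is treated as MB
// throughout this file (AddOneArena() adds it to cur_size_in_mb_, and the
// lookups above multiply it by 1048576 to get bytes).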
CircularPool::CircularPool(int max_size_in_gb, int arena_size, bool is_cuda_malloc)
    : unlimited_(max_size_in_gb <= 0),
      max_size_in_mb_(unlimited_ ? std::numeric_limits<int32_t>::max() : max_size_in_gb * 1024),
      arena_size_(arena_size),
      is_cuda_malloc_(is_cuda_malloc),
      cur_size_in_mb_(0) {}

Status CircularPool::CreateCircularPool(std::shared_ptr<MemoryPool> *out_pool, int max_size_in_gb, int arena_size,
                                        bool createOneArena, bool is_cuda_malloc) {
  Status rc;
  if (out_pool == nullptr) {
    RETURN_STATUS_UNEXPECTED("out_pool is null");
  }
  auto pool = new (std::nothrow) CircularPool(max_size_in_gb, arena_size, is_cuda_malloc);
  if (pool == nullptr) {
    RETURN_STATUS_OOM("Out of memory.");
  }
  if (createOneArena) {
    rc = pool->AddOneArena();
  }
  if (rc.IsOk()) {
    (*out_pool).reset(pool);
  } else {
    delete pool;
  }
  return rc;
}
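
// A minimal usage sketch (illustrative only; the sizes here are arbitrary
// and error handling is elided):
//
//   std::shared_ptr<MemoryPool> pool;
//   // 2 GB cap, 16 MB arenas, create the first arena eagerly, host memory.
//   RETURN_IF_NOT_OK(CircularPool::CreateCircularPool(&pool, 2, 16, true, false));
//   void *buf = nullptr;
//   RETURN_IF_NOT_OK(pool->Allocate(1024, &buf));
//   pool->Deallocate(buf);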

CircularPool::~CircularPool() = default;
}  // namespace dataset
}  // namespace mindspore