1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_ALLOCATOR_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_ALLOCATOR_H_
18
#include <cstddef>
#include <cstdlib>
#include <functional>
#include <limits>
#include <memory>
#include <new>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include "minddata/dataset/util/memory_pool.h"
25
26 namespace mindspore {
27 namespace dataset {
28 // The following conforms to the requirements of
29 // std::allocator. Do not rename/change any needed
30 // requirements, e.g. function names, typedef etc.
31 template <typename T>
32 class Allocator {
33 public:
34 template <typename U>
35 friend class Allocator;
36
37 using value_type = T;
38 using pointer = T *;
39 using const_pointer = const T *;
40 using reference = T &;
41 using const_reference = const T &;
42 using size_type = uint64_t;
43 using difference_type = std::ptrdiff_t;
44
45 template <typename U>
46 struct rebind {
47 using other = Allocator<U>;
48 };
49
50 using propagate_on_container_copy_assignment = std::true_type;
51 using propagate_on_container_move_assignment = std::true_type;
52 using propagate_on_container_swap = std::true_type;
53
Allocator(const std::shared_ptr<MemoryPool> & b)54 explicit Allocator(const std::shared_ptr<MemoryPool> &b) : pool_(b) {}
55
56 ~Allocator() = default;
57
58 template <typename U>
Allocator(Allocator<U> const & rhs)59 explicit Allocator(Allocator<U> const &rhs) : pool_(rhs.pool_) {}
60
61 template <typename U>
62 bool operator==(Allocator<U> const &rhs) const {
63 return pool_ == rhs.pool_;
64 }
65
66 template <typename U>
67 bool operator!=(Allocator<U> const &rhs) const {
68 return pool_ != rhs.pool_;
69 }
70
allocate(std::size_t n)71 pointer allocate(std::size_t n) {
72 void *p = nullptr;
73 Status rc = pool_->Allocate(n * sizeof(T), &p);
74 if (rc.IsOk()) {
75 return reinterpret_cast<pointer>(p);
76 } else if (rc == StatusCode::kMDOutOfMemory) {
77 throw std::bad_alloc();
78 } else {
79 throw std::exception();
80 }
81 }
82
83 void deallocate(pointer p, std::size_t n = 0) noexcept { pool_->Deallocate(p); }
84
max_size()85 size_type max_size() { return pool_->get_max_size(); }
86
87 private:
88 std::shared_ptr<MemoryPool> pool_;
89 };
90 /// \brief It is a wrapper of unique_ptr with a custom Allocator class defined above
91 template <typename T, typename C = std::allocator<T>, typename... Args>
MakeUnique(std::unique_ptr<T[],std::function<void (T *)>> * out,C alloc,size_t n,Args &&...args)92 Status MakeUnique(std::unique_ptr<T[], std::function<void(T *)>> *out, C alloc, size_t n, Args &&... args) {
93 RETURN_UNEXPECTED_IF_NULL(out);
94 CHECK_FAIL_RETURN_UNEXPECTED(n > 0, "size must be positive");
95 T *data = nullptr;
96 try {
97 data = alloc.allocate(n);
98 // Some of our implementation of allocator (e.g. NumaAllocator) don't throw std::bad_alloc.
99 // So we have to catch for null ptr
100 if (data == nullptr) {
101 return Status(StatusCode::kMDOutOfMemory);
102 }
103 if (!std::is_arithmetic<T>::value) {
104 for (size_t i = 0; i < n; i++) {
105 std::allocator_traits<C>::construct(alloc, &(data[i]), std::forward<Args>(args)...);
106 }
107 }
108 auto deleter = [](T *p, C f_alloc, size_t f_n) {
109 if (!std::is_arithmetic<T>::value && std::is_destructible<T>::value) {
110 for (size_t i = 0; i < f_n; ++i) {
111 std::allocator_traits<C>::destroy(f_alloc, &p[i]);
112 }
113 }
114 f_alloc.deallocate(p, f_n);
115 };
116 *out = std::unique_ptr<T[], std::function<void(T *)>>(data, std::bind(deleter, std::placeholders::_1, alloc, n));
117 } catch (const std::bad_alloc &e) {
118 if (data != nullptr) {
119 alloc.deallocate(data, n);
120 }
121 return Status(StatusCode::kMDOutOfMemory);
122 } catch (const std::exception &e) {
123 if (data != nullptr) {
124 alloc.deallocate(data, n);
125 }
126 RETURN_STATUS_UNEXPECTED(e.what());
127 }
128 return Status::OK();
129 }
130
131 /// \brief It is a wrapper of the above custom unique_ptr with some additional methods
132 /// \tparam T The type of object to be allocated
133 /// \tparam C Allocator. Default to std::allocator
134 template <typename T, typename C = std::allocator<T>>
135 class MemGuard {
136 public:
137 using allocator = C;
MemGuard()138 MemGuard() : n_(0) {}
MemGuard(allocator a)139 explicit MemGuard(allocator a) : n_(0), alloc_(a) {}
140 // There is no copy constructor nor assignment operator because the memory is solely owned by this object.
141 MemGuard(const MemGuard &) = delete;
142 MemGuard &operator=(const MemGuard &) = delete;
143 // On the other hand, We can support move constructor
MemGuard(MemGuard && lhs)144 MemGuard(MemGuard &&lhs) noexcept : n_(lhs.n_), alloc_(std::move(lhs.alloc_)), ptr_(std::move(lhs.ptr_)) {}
145 MemGuard &operator=(MemGuard &&lhs) noexcept {
146 if (this != &lhs) {
147 this->deallocate();
148 n_ = lhs.n_;
149 alloc_ = std::move(lhs.alloc_);
150 ptr_ = std::move(lhs.ptr_);
151 }
152 return *this;
153 }
154 /// \brief Explicitly deallocate the memory if allocated
deallocate()155 void deallocate() {
156 if (ptr_) {
157 ptr_.reset();
158 }
159 }
160 /// \brief Allocate memory (with emplace feature). Previous one will be released. If size is 0, no new memory is
161 /// allocated.
162 /// \param n Number of objects of type T to be allocated
163 /// \tparam Args Extra arguments pass to the constructor of T
164 template <typename... Args>
allocate(size_t n,Args &&...args)165 Status allocate(size_t n, Args &&... args) noexcept {
166 deallocate();
167 n_ = n;
168 return MakeUnique(&ptr_, alloc_, n, std::forward<Args>(args)...);
169 }
~MemGuard()170 ~MemGuard() noexcept { deallocate(); }
171 /// \brief Getter function
172 /// \return The pointer to the memory allocated
GetPointer()173 T *GetPointer() const { return ptr_.get(); }
174 /// \brief Getter function
175 /// \return The pointer to the memory allocated
GetMutablePointer()176 T *GetMutablePointer() { return ptr_.get(); }
177 /// \brief Overload [] operator to access a particular element
178 /// \param x index to the element. Must be less than number of element allocated.
179 /// \return pointer to the x-th element
180 T *operator[](size_t x) { return GetMutablePointer() + x; }
181 /// \brief Overload [] operator to access a particular element
182 /// \param x index to the element. Must be less than number of element allocated.
183 /// \return pointer to the x-th element
184 T *operator[](size_t x) const { return GetPointer() + x; }
185 /// \brief Return how many bytes are allocated in total
186 /// \return Number of bytes allocated in total
GetSizeInBytes()187 size_t GetSizeInBytes() const { return n_ * sizeof(T); }
188
189 private:
190 size_t n_;
191 allocator alloc_;
192 std::unique_ptr<T[], std::function<void(T *)>> ptr_;
193 };
194 } // namespace dataset
195 } // namespace mindspore
196
197 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_ALLOCATOR_H_
198