1 /*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #ifndef TASK_FACTORY_HPP
17 #define TASK_FACTORY_HPP
18
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

#include "tm/task_base.h"
#include "util/cb_func.h"
#include "util/slab.h"
22
23 namespace ffrt {
24 template <typename T>
25 class TaskFactory {
26 public:
27 static TaskFactory<T>& Instance();
28
Alloc()29 static T* Alloc()
30 {
31 return Instance().alloc_();
32 }
33
Free(T * task)34 static void Free(T* task)
35 {
36 Instance().free_(task);
37 }
38
Free_(T * task)39 static void Free_(T* task)
40 {
41 if (Instance().free__ != nullptr) {
42 Instance().free__(task);
43 }
44 }
45
GetUnfreedMem()46 static std::vector<void*> GetUnfreedMem()
47 {
48 if (Instance().getUnfreedMem_ != nullptr) {
49 return Instance().getUnfreedMem_();
50 }
51 return {};
52 }
53
GetUnfreedMemSize()54 static std::size_t GetUnfreedMemSize()
55 {
56 if (Instance().getUnfreedMemSize_ != nullptr) {
57 return Instance().getUnfreedMemSize_();
58 }
59 return 0;
60 }
61
GetUnfreedTasksFiltered()62 static std::vector<void*> GetUnfreedTasksFiltered()
63 {
64 LockMem();
65 std::vector<void*> unfreed = GetUnfreedMem();
66 // Filter out tasks where the reference count increment failed.
67 unfreed.erase(
68 std::remove_if(unfreed.begin(), unfreed.end(),
69 [](void* task) {
70 return !IncDeleteRefIfPositive(reinterpret_cast<TaskBase*>(task));
71 }),
72 unfreed.end()
73 );
74 UnlockMem();
75 return unfreed;
76 }
77
HasBeenFreed(T * task)78 static bool HasBeenFreed(T* task)
79 {
80 if (Instance().hasBeenFreed_ != nullptr) {
81 return Instance().hasBeenFreed_(task);
82 }
83 return true;
84 }
85
LockMem()86 static void LockMem()
87 {
88 if (Instance().lockMem_ != nullptr) {
89 Instance().lockMem_();
90 }
91 }
92
UnlockMem()93 static void UnlockMem()
94 {
95 if (Instance().unlockMem_ != nullptr) {
96 Instance().unlockMem_();
97 }
98 }
99
100 static void RegistCb(
101 typename TaskAllocCB<T>::Alloc &&alloc,
102 typename TaskAllocCB<T>::Free &&free,
103 typename TaskAllocCB<T>::Free_ &&free_ = nullptr,
104 typename TaskAllocCB<T>::GetUnfreedMem &&getUnfreedMem = nullptr,
105 typename TaskAllocCB<T>::GetUnfreedMemSize &&getUnfreedMemSize = nullptr,
106 typename TaskAllocCB<T>::HasBeenFreed &&hasBeenFreed = nullptr,
107 typename TaskAllocCB<T>::LockMem &&lockMem = nullptr,
108 typename TaskAllocCB<T>::UnlockMem &&unlockMem = nullptr)
109 {
110 Instance().alloc_ = std::move(alloc);
111 Instance().free_ = std::move(free);
112 Instance().free__ = std::move(free_);
113 Instance().getUnfreedMem_ = std::move(getUnfreedMem);
114 Instance().getUnfreedMemSize_ = std::move(getUnfreedMemSize);
115 Instance().hasBeenFreed_ = std::move(hasBeenFreed);
116 Instance().lockMem_ = std::move(lockMem);
117 Instance().unlockMem_ = std::move(unlockMem);
118 }
119
120 private:
121 typename TaskAllocCB<T>::Alloc alloc_;
122 typename TaskAllocCB<T>::Free free_;
123 typename TaskAllocCB<T>::Free_ free__;
124 typename TaskAllocCB<T>::GetUnfreedMem getUnfreedMem_;
125 typename TaskAllocCB<T>::GetUnfreedMemSize getUnfreedMemSize_;
126 typename TaskAllocCB<T>::HasBeenFreed hasBeenFreed_;
127 typename TaskAllocCB<T>::LockMem lockMem_;
128 typename TaskAllocCB<T>::UnlockMem unlockMem_;
129 };
130
131 template <typename T>
132 class TaskMemScopedLock {
133 public:
TaskMemScopedLock()134 TaskMemScopedLock()
135 {
136 TaskFactory<T>::LockMem();
137 }
138
~TaskMemScopedLock()139 ~TaskMemScopedLock()
140 {
141 TaskFactory<T>::UnlockMem();
142 }
143 };
144
145 template <typename FactoryTaskType, typename AllocatorTaskType = FactoryTaskType>
RegisterTaskFactoryCallbacks()146 void RegisterTaskFactoryCallbacks()
147 {
148 ffrt::TaskFactory<FactoryTaskType>::RegistCb(
149 [] () -> FactoryTaskType* {
150 return ffrt::SimpleAllocator<AllocatorTaskType>::AllocMem();
151 },
152 [] (FactoryTaskType* task) {
153 ffrt::SimpleAllocator<AllocatorTaskType>::FreeMem(static_cast<AllocatorTaskType*>(task));
154 },
155 [] (FactoryTaskType* task) {
156 ffrt::SimpleAllocator<AllocatorTaskType>::FreeMem_(static_cast<AllocatorTaskType*>(task));
157 },
158 ffrt::SimpleAllocator<AllocatorTaskType>::getUnfreedMem,
159 ffrt::SimpleAllocator<AllocatorTaskType>::getUnfreedMemSize,
160 [] (FactoryTaskType* task) {
161 return ffrt::SimpleAllocator<AllocatorTaskType>::HasBeenFreed(static_cast<AllocatorTaskType*>(task));
162 },
163 ffrt::SimpleAllocator<AllocatorTaskType>::LockMem,
164 ffrt::SimpleAllocator<AllocatorTaskType>::UnlockMem);
165 }
166 } // namespace ffrt
167
168 #endif