/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Abstractions for processing small tasks in a batched fashion, to reduce
// processing times and costs that can be amortized across multiple tasks.
//
// The core class is BatchScheduler, which encapsulates the logic for
// aggregating multiple tasks into a batch and kicking off the processing of
// each batch on a thread pool that it manages.
//
// This file defines an abstract BatchScheduler class.

#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_H_
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_H_

#include <stddef.h>
#include <algorithm>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace serving {

// The abstract superclass for a unit of work to be done as part of a batch.
//
// An implementing subclass typically contains (or points to):
//  (a) input data;
//  (b) a thread-safe completion signal (e.g. a Notification);
//  (c) a place to store the outcome (success, or some error), upon completion;
//  (d) a place to store the output data, upon success.
//
// Items (b), (c) and (d) are typically non-owned pointers to data homed
// elsewhere, because a task's ownership gets transferred to a BatchScheduler
// (see below) and it may be deleted as soon as it is done executing.
class BatchTask {
 public:
  virtual ~BatchTask() = default;

  // Returns the size of the task, in terms of how much it contributes to the
  // size of a batch. (A batch's size is the sum of its task sizes.)
  virtual size_t size() const = 0;
};
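
// Example subclass (a minimal sketch; the names 'EchoTask', 'input_', etc.
// are illustrative, not part of this API). It bundles input data, a
// completion Notification, and slots for the outcome and output, matching
// items (a)-(d) above:
//
//   class EchoTask : public BatchTask {
//    public:
//     EchoTask(std::vector<float> input, Notification* done, Status* status,
//              std::vector<float>* output)
//         : input_(std::move(input)),
//           done_(done),
//           status_(status),
//           output_(output) {}
//
//     size_t size() const override { return input_.size(); }
//
//     std::vector<float> input_;     // (a) input data (owned)
//     Notification* done_;           // (b) completion signal (not owned)
//     Status* status_;               // (c) outcome slot (not owned)
//     std::vector<float>* output_;   // (d) output slot (not owned)
//   };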

// A thread-safe collection of BatchTasks, to be executed together in some
// fashion.
//
// At a given time, a batch is either "open" or "closed": an open batch can
// accept new tasks; a closed one cannot. A batch is monotonic: initially it is
// open and tasks can be added to it; then it is closed and its set of tasks
// remains fixed for the remainder of its life. A closed batch cannot be
// re-opened. Tasks can never be removed from a batch.
//
// Type parameter TaskType must be a subclass of BatchTask.
template <typename TaskType>
class Batch {
 public:
  Batch() = default;
  virtual ~Batch();  // Blocks until the batch is closed.

  // Appends 'task' to the batch. After calling AddTask(), the newly-added task
  // can be accessed via task(num_tasks()-1) or mutable_task(num_tasks()-1).
  // Dies if the batch is closed.
  void AddTask(std::unique_ptr<TaskType> task);

  // Removes the most recently added task from the batch and returns it.
  // Returns nullptr if the batch is empty.
  std::unique_ptr<TaskType> RemoveTask();

  // Returns the number of tasks in the batch.
  int num_tasks() const;

  // Returns true iff the batch contains 0 tasks.
  bool empty() const;

  // Returns a reference to the ith task (in terms of insertion order).
  const TaskType& task(int i) const;

  // Returns a pointer to the ith task (in terms of insertion order).
  TaskType* mutable_task(int i);

  // Returns the sum of the task sizes.
  size_t size() const;

  // Returns true iff the batch is currently closed.
  bool IsClosed() const;

  // Blocks until the batch is closed.
  void WaitUntilClosed() const;

  // Marks the batch as closed. Dies if called more than once.
  void Close();

 private:
  mutable mutex mu_;

  // The tasks in the batch.
  std::vector<std::unique_ptr<TaskType>> tasks_ GUARDED_BY(mu_);

  // The sum of the sizes of the tasks in 'tasks_'.
  size_t size_ GUARDED_BY(mu_) = 0;

  // Whether the batch has been closed.
  Notification closed_;

  TF_DISALLOW_COPY_AND_ASSIGN(Batch);
};
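
// Example use (a sketch, assuming a BatchTask subclass 'EchoTask' as above):
// one thread fills and closes a batch, while another may block on
// WaitUntilClosed() before processing the now-immutable set of tasks.
//
//   Batch<EchoTask> batch;
//   batch.AddTask(std::move(task_a));  // Batch is open; tasks may be added.
//   batch.AddTask(std::move(task_b));
//   batch.Close();                     // Batch is now closed and immutable.
//
//   // Possibly on another thread:
//   batch.WaitUntilClosed();
//   for (int i = 0; i < batch.num_tasks(); ++i) {
//     ProcessTask(batch.mutable_task(i));  // 'ProcessTask' is hypothetical.
//   }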

// An abstract batch scheduler class. Collects individual tasks into batches,
// and processes each batch on a pool of "batch threads" that it manages. The
// actual logic for processing a batch is accomplished via a callback.
//
// Type parameter TaskType must be a subclass of BatchTask.
template <typename TaskType>
class BatchScheduler {
 public:
  virtual ~BatchScheduler() = default;

  // Submits a task to be processed as part of a batch.
  //
  // Ownership of '*task' is transferred to the callee iff the method returns
  // Status::OK. In that case, '*task' is left as nullptr. Otherwise, '*task'
  // is left as-is.
  //
  // If no batch processing capacity is available to process this task at the
  // present time, and any task queue maintained by the implementing subclass
  // is full, this method returns an UNAVAILABLE error code. The client may
  // retry later.
  //
  // Other problems, such as the task size being larger than the maximum batch
  // size, yield other, permanent error types.
  //
  // In all cases, this method returns "quickly" without blocking for any
  // substantial amount of time. If the method returns Status::OK, the task is
  // processed asynchronously, and any errors that occur during the processing
  // of the batch that includes the task can be reported to 'task'.
  virtual Status Schedule(std::unique_ptr<TaskType>* task) = 0;

  // Returns the number of tasks that have been scheduled (i.e. accepted by
  // Schedule()), but have yet to be handed to a thread for execution as part
  // of a batch. Note that this returns the number of tasks, not the aggregate
  // task size (so if there is one task of size 3 and one task of size 5, this
  // method returns 2 rather than 8).
  virtual size_t NumEnqueuedTasks() const = 0;

  // Returns a guaranteed number of size-1 tasks that can be Schedule()d
  // without getting an UNAVAILABLE error. In a typical implementation, this is
  // the available space on a queue.
  //
  // There are two important caveats:
  //  1. The guarantee does not extend to varying-size tasks, due to possible
  //     internal fragmentation of batches.
  //  2. The guarantee only holds in a single-thread environment or critical
  //     section, i.e. when no intervening thread can call Schedule().
  //
  // This method is useful for monitoring, or for guaranteeing a future slot in
  // the schedule (but being mindful about the caveats listed above).
  virtual size_t SchedulingCapacity() const = 0;

  // Returns the maximum allowed size of tasks submitted to the scheduler.
  // (This is typically equal to a configured maximum batch size.)
  virtual size_t max_task_size() const = 0;
};
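
// Example use (a sketch; 'scheduler' is any concrete BatchScheduler
// implementation, and 'task' is a std::unique_ptr<EchoTask> as above). Note
// that ownership transfers only on success, and that an UNAVAILABLE status
// maps to a retry:
//
//   if (scheduler->SchedulingCapacity() == 0) {
//     // Queue is full; consider backing off (see the caveats above).
//   }
//   Status status = scheduler->Schedule(&task);
//   if (status.ok()) {
//     // 'task' is now nullptr; the scheduler owns it and will process it
//     // asynchronously as part of some batch.
//   } else if (status.code() == error::UNAVAILABLE) {
//     // 'task' is untouched; the client may retry later.
//   } else {
//     // Permanent error, e.g. task size exceeds max_task_size().
//   }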

//////////
// Implementation details follow. API users need not read.

template <typename TaskType>
Batch<TaskType>::~Batch() {
  WaitUntilClosed();
}

template <typename TaskType>
void Batch<TaskType>::AddTask(std::unique_ptr<TaskType> task) {
  DCHECK(!IsClosed());
  {
    mutex_lock l(mu_);
    size_ += task->size();
    tasks_.push_back(std::move(task));
  }
}

template <typename TaskType>
std::unique_ptr<TaskType> Batch<TaskType>::RemoveTask() {
  {
    mutex_lock l(mu_);
    if (tasks_.empty()) {
      return nullptr;
    }
    std::unique_ptr<TaskType> task = std::move(tasks_.back());
    size_ -= task->size();
    tasks_.pop_back();
    return task;
  }
}

template <typename TaskType>
int Batch<TaskType>::num_tasks() const {
  {
    mutex_lock l(mu_);
    return tasks_.size();
  }
}

template <typename TaskType>
bool Batch<TaskType>::empty() const {
  {
    mutex_lock l(mu_);
    return tasks_.empty();
  }
}

template <typename TaskType>
const TaskType& Batch<TaskType>::task(int i) const {
  DCHECK_GE(i, 0);
  {
    mutex_lock l(mu_);
    DCHECK_LT(i, tasks_.size());
    return *tasks_[i].get();
  }
}

template <typename TaskType>
TaskType* Batch<TaskType>::mutable_task(int i) {
  DCHECK_GE(i, 0);
  {
    mutex_lock l(mu_);
    DCHECK_LT(i, tasks_.size());
    return tasks_[i].get();
  }
}

template <typename TaskType>
size_t Batch<TaskType>::size() const {
  {
    mutex_lock l(mu_);
    return size_;
  }
}

template <typename TaskType>
bool Batch<TaskType>::IsClosed() const {
  // The const_cast works around Notification's methods not being declared
  // const; HasBeenNotified() does not mutate the batch.
  return const_cast<Notification*>(&closed_)->HasBeenNotified();
}

template <typename TaskType>
void Batch<TaskType>::WaitUntilClosed() const {
  // As in IsClosed(), the const_cast only works around the missing const
  // qualifier on WaitForNotification(); no batch state is modified.
  const_cast<Notification*>(&closed_)->WaitForNotification();
}

template <typename TaskType>
void Batch<TaskType>::Close() {
  // Notify() dies if called more than once, which enforces the "close at most
  // once" contract documented above.
  closed_.Notify();
}

}  // namespace serving
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_H_