• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// Abstractions for processing small tasks in a batched fashion, to reduce
// processing time and cost by amortizing them across multiple tasks.
//
// The core class is BatchScheduler, which groups tasks into batches.
//
// BatchScheduler encapsulates logic for aggregating multiple tasks into a
// batch, and kicking off processing of a batch on a thread pool it manages.
//
// This file defines the BatchTask and Batch helper classes, and an abstract
// BatchScheduler class.
25 
26 #ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_H_
27 #define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_H_
28 
29 #include <stddef.h>
30 #include <algorithm>
31 #include <functional>
32 #include <memory>
33 #include <utility>
34 #include <vector>
35 
36 #include "tensorflow/core/lib/core/notification.h"
37 #include "tensorflow/core/lib/core/status.h"
38 #include "tensorflow/core/platform/logging.h"
39 #include "tensorflow/core/platform/macros.h"
40 #include "tensorflow/core/platform/mutex.h"
41 #include "tensorflow/core/platform/thread_annotations.h"
42 #include "tensorflow/core/platform/types.h"
43 
44 namespace tensorflow {
45 namespace serving {
46 
// Base class for a single unit of work that is processed as part of a batch.
//
// A concrete subclass usually holds, either directly or by pointer:
//  (a) the input data;
//  (b) a thread-safe completion signal (for example, a Notification);
//  (c) somewhere to record the outcome (success or an error) once done;
//  (d) somewhere to write the output data on success.
//
// Items (b), (c) and (d) are usually non-owning pointers to data that lives
// elsewhere: ownership of a task is handed over to a BatchScheduler (declared
// below), and the task may be deleted as soon as it is done executing.
class BatchTask {
 public:
  virtual ~BatchTask() = default;

  // How much this task contributes to the size of a batch. The size of a
  // batch is the sum of the sizes of the tasks it contains.
  virtual size_t size() const = 0;
};
66 
67 // A thread-safe collection of BatchTasks, to be executed together in some
68 // fashion.
69 //
70 // At a given time, a batch is either "open" or "closed": an open batch can
71 // accept new tasks; a closed one cannot. A batch is monotonic: initially it is
72 // open and tasks can be added to it; then it is closed and its set of tasks
73 // remains fixed for the remainder of its life. A closed batch cannot be re-
74 // opened. Tasks can never be removed from a batch.
75 //
76 // Type parameter TaskType must be a subclass of BatchTask.
77 template <typename TaskType>
78 class Batch {
79  public:
80   Batch() = default;
81   virtual ~Batch();  // Blocks until the batch is closed.
82 
83   // Appends 'task' to the batch. After calling AddTask(), the newly-added task
84   // can be accessed via task(num_tasks()-1) or mutable_task(num_tasks()-1).
85   // Dies if the batch is closed.
86   void AddTask(std::unique_ptr<TaskType> task);
87 
88   // Removes the most recently added task. Returns nullptr if the batch is
89   // empty.
90   std::unique_ptr<TaskType> RemoveTask();
91 
92   // Returns the number of tasks in the batch.
93   int num_tasks() const;
94 
95   // Returns true iff the batch contains 0 tasks.
96   bool empty() const;
97 
98   // Returns a reference to the ith task (in terms of insertion order).
99   const TaskType& task(int i) const;
100 
101   // Returns a pointer to the ith task (in terms of insertion order).
102   TaskType* mutable_task(int i);
103 
104   // Returns the sum of the task sizes.
105   size_t size() const;
106 
107   // Returns true iff the batch is currently closed.
108   bool IsClosed() const;
109 
110   // Blocks until the batch is closed.
111   void WaitUntilClosed() const;
112 
113   // Marks the batch as closed. Dies if called more than once.
114   void Close();
115 
116  private:
117   mutable mutex mu_;
118 
119   // The tasks in the batch.
120   std::vector<std::unique_ptr<TaskType>> tasks_ GUARDED_BY(mu_);
121 
122   // The sum of the sizes of the tasks in 'tasks_'.
123   size_t size_ GUARDED_BY(mu_) = 0;
124 
125   // Whether the batch has been closed.
126   Notification closed_;
127 
128   TF_DISALLOW_COPY_AND_ASSIGN(Batch);
129 };
130 
131 // An abstract batch scheduler class. Collects individual tasks into batches,
132 // and processes each batch on a pool of "batch threads" that it manages. The
133 // actual logic for processing a batch is accomplished via a callback.
134 //
135 // Type parameter TaskType must be a subclass of BatchTask.
136 template <typename TaskType>
137 class BatchScheduler {
138  public:
139   virtual ~BatchScheduler() = default;
140 
141   // Submits a task to be processed as part of a batch.
142   //
143   // Ownership of '*task' is transferred to the callee iff the method returns
144   // Status::OK. In that case, '*task' is left as nullptr. Otherwise, '*task' is
145   // left as-is.
146   //
147   // If no batch processing capacity is available to process this task at the
148   // present time, and any task queue maintained by the implementing subclass is
149   // full, this method returns an UNAVAILABLE error code. The client may retry
150   // later.
151   //
152   // Other problems, such as the task size being larger than the maximum batch
153   // size, yield other, permanent error types.
154   //
155   // In all cases, this method returns "quickly" without blocking for any
156   // substantial amount of time. If the method returns Status::OK, the task is
157   // processed asynchronously, and any errors that occur during the processing
158   // of the batch that includes the task can be reported to 'task'.
159   virtual Status Schedule(std::unique_ptr<TaskType>* task) = 0;
160 
161   // Returns the number of tasks that have been scheduled (i.e. accepted by
162   // Schedule()), but have yet to be handed to a thread for execution as part of
163   // a batch. Note that this returns the number of tasks, not the aggregate task
164   // size (so if there is one task of size 3 and one task of size 5, this method
165   // returns 2 rather than 8).
166   virtual size_t NumEnqueuedTasks() const = 0;
167 
168   // Returns a guaranteed number of size 1 tasks that can be Schedule()d without
169   // getting an UNAVAILABLE error. In a typical implementation, returns the
170   // available space on a queue.
171   //
172   // There are two important caveats:
173   //  1. The guarantee does not extend to varying-size tasks due to possible
174   //     internal fragmentation of batches.
175   //  2. The guarantee only holds in a single-thread environment or critical
176   //     section, i.e. if an intervening thread cannot call Schedule().
177   //
178   // This method is useful for monitoring, or for guaranteeing a future slot in
179   // the schedule (but being mindful about the caveats listed above).
180   virtual size_t SchedulingCapacity() const = 0;
181 
182   // Returns the maximum allowed size of tasks submitted to the scheduler. (This
183   // is typically equal to a configured maximum batch size.)
184   virtual size_t max_task_size() const = 0;
185 };
186 
187 //////////
188 // Implementation details follow. API users need not read.
189 
190 template <typename TaskType>
~Batch()191 Batch<TaskType>::~Batch() {
192   WaitUntilClosed();
193 }
194 
195 template <typename TaskType>
AddTask(std::unique_ptr<TaskType> task)196 void Batch<TaskType>::AddTask(std::unique_ptr<TaskType> task) {
197   DCHECK(!IsClosed());
198   {
199     mutex_lock l(mu_);
200     size_ += task->size();
201     tasks_.push_back(std::move(task));
202   }
203 }
204 
205 template <typename TaskType>
RemoveTask()206 std::unique_ptr<TaskType> Batch<TaskType>::RemoveTask() {
207   {
208     mutex_lock l(mu_);
209     if (tasks_.empty()) {
210       return nullptr;
211     }
212     std::unique_ptr<TaskType> task = std::move(tasks_.back());
213     size_ -= task->size();
214     tasks_.pop_back();
215     return task;
216   }
217 }
218 
219 template <typename TaskType>
num_tasks()220 int Batch<TaskType>::num_tasks() const {
221   {
222     mutex_lock l(mu_);
223     return tasks_.size();
224   }
225 }
226 
227 template <typename TaskType>
empty()228 bool Batch<TaskType>::empty() const {
229   {
230     mutex_lock l(mu_);
231     return tasks_.empty();
232   }
233 }
234 
235 template <typename TaskType>
task(int i)236 const TaskType& Batch<TaskType>::task(int i) const {
237   DCHECK_GE(i, 0);
238   {
239     mutex_lock l(mu_);
240     DCHECK_LT(i, tasks_.size());
241     return *tasks_[i].get();
242   }
243 }
244 
245 template <typename TaskType>
mutable_task(int i)246 TaskType* Batch<TaskType>::mutable_task(int i) {
247   DCHECK_GE(i, 0);
248   {
249     mutex_lock l(mu_);
250     DCHECK_LT(i, tasks_.size());
251     return tasks_[i].get();
252   }
253 }
254 
255 template <typename TaskType>
size()256 size_t Batch<TaskType>::size() const {
257   {
258     mutex_lock l(mu_);
259     return size_;
260   }
261 }
262 
263 template <typename TaskType>
IsClosed()264 bool Batch<TaskType>::IsClosed() const {
265   return const_cast<Notification*>(&closed_)->HasBeenNotified();
266 }
267 
268 template <typename TaskType>
WaitUntilClosed()269 void Batch<TaskType>::WaitUntilClosed() const {
270   const_cast<Notification*>(&closed_)->WaitForNotification();
271 }
272 
273 template <typename TaskType>
Close()274 void Batch<TaskType>::Close() {
275   closed_.Notify();
276 }
277 
278 }  // namespace serving
279 }  // namespace tensorflow
280 
281 #endif  // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_H_
282