/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_INPUT_TASK_H_
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_INPUT_TASK_H_

#include <algorithm>
#include <atomic>
#include <functional>
#include <memory>
#include <utility>

#include "absl/base/call_once.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/concat_split_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/util/incremental_barrier.h"

namespace tensorflow {
namespace serving {

namespace internal {
template <typename TaskType>
class BatchInputTaskHandleTestAccess;

template <typename TaskType>
class BatchInputTaskTestAccess;
}  // namespace internal

template <typename TaskType>
class BatchInputTask;

// An RAII-style object that holds a ref-counted batch-input-task and
// represents one slice of that batch-input-task.
//
// Handles are handed out to callers of `BatchInputTask::ToTaskHandles`
// quickly (i.e., without necessarily waiting for the input split).
//
// `BatchInputTaskHandle::GetSplitTask` evaluates to the slice of the task.
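//
// Minimal usage sketch (editor's illustration; `handle` is assumed to come
// from `BatchInputTask::ToTaskHandles`):
//
//   std::unique_ptr<TaskType> task = handle->GetSplitTask();
//   ... use `task` ...
//   // Per the contract above, any subsequent call returns nullptr:
//   std::unique_ptr<TaskType> again = handle->GetSplitTask();  // == nullptr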
template <typename TaskType>
class BatchInputTaskHandle : public BatchTask {
 public:
  BatchInputTaskHandle(
      std::shared_ptr<BatchInputTask<TaskType>> batch_input_task, int split_id,
      size_t task_size);

  // Should be called once. Returns nullptr on subsequent calls.
  std::unique_ptr<TaskType> GetSplitTask();

  // Returns the size of this task.
  size_t size() const override { return task_size_; }

 private:
  template <typename T>
  friend class internal::BatchInputTaskHandleTestAccess;

  int split_id() const { return split_id_; }

  std::shared_ptr<BatchInputTask<TaskType>> batch_input_task_;

  // The handle evaluates to the N-th slice of the original task, where
  // N is `split_id_`.
  const int split_id_;

  const size_t task_size_;

  std::atomic<bool> once_{false};
};

// BatchInputTask encapsulates an input (`input_task`) to be batched and the
// information needed to get task splits after it's enqueued, so as to support
// lazy split of a task.
//
// Input split can reduce excessive padding for efficiency; lazy split moves
// task-splitting out of the critical path of enqueue and dequeue and reduces
// contention.
//
// BatchInputTask is thread safe.
//
// Usage
//
// ... a deque with frequent enqueue and dequeue operations ...
// ... Note: a deque of `Batch<BatchInputTaskHandle>` is used to form batches
//     at enqueue time (the split itself is deferred to dequeue time);
// ... For use cases that form batches at dequeue time, a deque of
//     BatchInputTaskHandle can be used directly, "peeking" at metadata to
//     form a batch at that point.
// std::deque<std::unique_ptr<Batch<BatchInputTaskHandle<TaskType>>>> deque_
//     TF_GUARDED_BY(mu_);
//
// std::unique_ptr<TaskType> input_task;
//
// ... Enqueue path ...
//
// {
//   mutex_lock l(mu_);
//   std::shared_ptr<BatchInputTask<TaskType>> batch_input_task =
//       ConstructLazyBatchWithoutSplit(input_task);
//
//   std::vector<std::unique_ptr<BatchInputTaskHandle<TaskType>>> task_handles;
//   batch_input_task->ToTaskHandles(&task_handles);
//   for (int i = 0; i < task_handles.size(); ++i) {
//     EnqueueTaskHandleIntoDeque(deque_, std::move(task_handles[i]));
//   }
// }
//
// ... Dequeue path ...
// std::unique_ptr<Batch<BatchInputTaskHandle<TaskType>>> handles_to_schedule;
// {
//    mutex_lock l(mu_);
//    ... HasBatchToSchedule could be customized or specialized
//    ... (e.g., readiness depending on enqueue time)
//    if (HasBatchToSchedule(deque_)) {
//      handles_to_schedule = std::move(deque_.front());
//      deque_.pop_front();
//    }
// }
// ...... `mu_` is released ......
//
// std::vector<std::unique_ptr<BatchInputTaskHandle<TaskType>>> tasks_in_batch =
//     RemoveAllTasksFromBatch(handles_to_schedule);
//
// std::unique_ptr<Batch<TaskType>> batch_to_schedule =
//     std::make_unique<Batch<TaskType>>();
// for (int i = 0; i < tasks_in_batch.size(); i++) {
//   batch_to_schedule->AddTask(tasks_in_batch[i]->GetSplitTask());
// }
// batch_to_schedule->Close();
//
// `batch_to_schedule` is ready for scheduling.
template <typename TaskType>
class BatchInputTask
    : public std::enable_shared_from_this<BatchInputTask<TaskType>> {
 public:
  using BatchSplitFunc = std::function<Status(
      std::unique_ptr<TaskType>* input_task, int first_output_task_size,
      int input_batch_size_limit,
      std::vector<std::unique_ptr<TaskType>>* output_tasks)>;
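
  // Example of a conforming split function (editor's sketch; `MyTask` is a
  // hypothetical task type and the `...` body is elided — any real caller
  // supplies its own task-specific split logic):
  //
  //   BatchInputTask<MyTask>::BatchSplitFunc split =
  //       [](std::unique_ptr<MyTask>* input_task, int first_output_task_size,
  //          int input_batch_size_limit,
  //          std::vector<std::unique_ptr<MyTask>>* output_tasks) -> Status {
  //     // Carve off `first_output_task_size` elements first, then chunks of
  //     // at most `input_batch_size_limit`, appending each to `output_tasks`.
  //     ...
  //     return Status();  // OK
  //   };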

  // TODO(b/194294263):
  // Add a SplitMetadataFunc to the constructor, so users of this class specify
  // both how to split and how to compute split metadata in a consistent way.
  BatchInputTask(std::unique_ptr<TaskType> input_task,
                 int open_batch_remaining_slot, int batch_size_limit,
                 BatchSplitFunc split_func);

  // Outputs the task handles for the input task.
  // Each task handle represents a slice of the task after the input task is
  // split, and can evaluate to that slice.
  //
  // NOTE:
  // Each task handle in `output_task_handles` takes ownership of a reference
  // to this BatchInputTask.
  void ToTaskHandles(
      std::vector<std::unique_ptr<BatchInputTaskHandle<TaskType>>>*
          output_task_handles);

 private:
  friend class BatchInputTaskHandle<TaskType>;
  template <typename T>
  friend class internal::BatchInputTaskTestAccess;

  // The following methods expose the split metadata of this task. The
  // metadata is used to determine batch construction, so it is needed
  // before the split happens.

  // Returns the task size of the N-th split; N is `split_id`.
  int GetTaskSize(int split_id) const;

  // Task size of `input_task`.
  size_t size() const;

  // The number of batches the input spans.
  int num_batches() const;

  // The number of new batches this input adds.
  int num_new_batches() const;

  // The task size of the head batch.
  int head_batch_task_size() const;

  // The task size of the last batch.
  int tail_batch_task_size() const;

  std::unique_ptr<TaskType> GetSplitTask(int split_id);

  Status SplitBatches(std::vector<std::unique_ptr<TaskType>>* output_tasks);

  std::unique_ptr<TaskType> input_task_;

  const int input_task_size_ = 0;
  const int open_batch_remaining_slot_;

  const int batch_size_limit_;

  const BatchSplitFunc split_func_;

  // The number of batches that this input appends to.
  // Should be either zero or one.
  const int num_batches_reused_ = 0;

  // The number of batches this input spans over.
  int num_batches_ = 0;

  // The task size of the last batch.
  int tail_batch_task_size_;

  // The task size of the first batch.
  int head_batch_task_size_;

  mutable absl::once_flag once_;

  std::vector<std::unique_ptr<TaskType>> task_splits_;
  Status split_status_;
};

//
// Implementation details. API readers may skip.
//

template <typename TaskType>
BatchInputTaskHandle<TaskType>::BatchInputTaskHandle(
    std::shared_ptr<BatchInputTask<TaskType>> batch_input_task, int split_id,
    size_t task_size)
    : batch_input_task_(batch_input_task),
      split_id_(split_id),
      task_size_(task_size) {}

template <typename TaskType>
std::unique_ptr<TaskType> BatchInputTaskHandle<TaskType>::GetSplitTask() {
  // Atomically flip `once_` so that exactly one caller proceeds even under
  // concurrent invocation; every later call observes `true` and gets nullptr.
  if (once_.exchange(true, std::memory_order_acq_rel)) {
    return nullptr;
  }
  return batch_input_task_->GetSplitTask(split_id_);
}

template <typename TaskType>
BatchInputTask<TaskType>::BatchInputTask(
    std::unique_ptr<TaskType> input_task, int open_batch_remaining_slot,
    int batch_size_limit,
    std::function<Status(std::unique_ptr<TaskType>* input_task,
                         int first_output_task_size, int input_batch_size_limit,
                         std::vector<std::unique_ptr<TaskType>>* output_tasks)>
        split_func)
    : input_task_(std::move(input_task)),
      input_task_size_(input_task_->size()),
      open_batch_remaining_slot_(open_batch_remaining_slot),
      batch_size_limit_(batch_size_limit),
      split_func_(split_func),
      num_batches_reused_((open_batch_remaining_slot_ > 0) ? 1 : 0) {
  // The total task size starting from the current open batch, after this task
  // is enqueued.
  const int task_size_from_open_batch =
      (open_batch_remaining_slot_ > 0)
          ? (input_task_size_ + batch_size_limit_ - open_batch_remaining_slot_)
          : input_task_size_;

  num_batches_ =
      (task_size_from_open_batch + batch_size_limit_ - 1) / batch_size_limit_;

  if (open_batch_remaining_slot_ == 0) {
    head_batch_task_size_ = std::min(input_task_size_, batch_size_limit_);
  } else {
    head_batch_task_size_ = (input_task_size_ >= open_batch_remaining_slot_)
                                ? open_batch_remaining_slot_
                                : input_task_size_;
  }
  if (input_task_size_ <= open_batch_remaining_slot_) {
    tail_batch_task_size_ = input_task_size_;
  } else {
    tail_batch_task_size_ = task_size_from_open_batch % batch_size_limit_;
    if (tail_batch_task_size_ == 0) {
      tail_batch_task_size_ = batch_size_limit_;
    }
  }
}
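
// Worked example (editor's illustration, not from the original source):
// suppose `input_task_size_` = 10, `open_batch_remaining_slot` = 3, and
// `batch_size_limit` = 4. Then:
//   task_size_from_open_batch = 10 + 4 - 3 = 11
//   num_batches_              = ceil(11 / 4) = 3
//   head_batch_task_size_     = 3   (fills the open batch)
//   tail_batch_task_size_     = 11 % 4 = 3
// The input is split into slices of sizes 3, 4, and 3, and
// `num_new_batches()` is 3 - 1 = 2, since one existing batch is reused.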

template <typename TaskType>
size_t BatchInputTask<TaskType>::size() const {
  return input_task_size_;
}

template <typename TaskType>
int BatchInputTask<TaskType>::num_batches() const {
  return num_batches_;
}

template <typename TaskType>
int BatchInputTask<TaskType>::num_new_batches() const {
  return num_batches_ - num_batches_reused_;
}

template <typename TaskType>
int BatchInputTask<TaskType>::head_batch_task_size() const {
  return head_batch_task_size_;
}

template <typename TaskType>
int BatchInputTask<TaskType>::tail_batch_task_size() const {
  return tail_batch_task_size_;
}

template <typename TaskType>
void BatchInputTask<TaskType>::ToTaskHandles(
    std::vector<std::unique_ptr<BatchInputTaskHandle<TaskType>>>*
        task_handles) {
  task_handles->resize(num_batches_);
  for (int i = 0; i < num_batches_; i++) {
    (*task_handles)[i] = std::make_unique<BatchInputTaskHandle<TaskType>>(
        this->shared_from_this(), i, GetTaskSize(i));
  }
}

template <typename TaskType>
int BatchInputTask<TaskType>::GetTaskSize(int split_id) const {
  if (split_id < 0 || split_id >= num_batches_) {
    return 0;
  }
  if (split_id == 0) {
    return head_batch_task_size_;
  }
  if (split_id == num_batches_ - 1) {
    return tail_batch_task_size_;
  }

  return batch_size_limit_;
}
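
// Continuing the worked example above (editor's illustration): with splits of
// sizes 3, 4, and 3, GetTaskSize(0) == 3, GetTaskSize(1) == 4, and
// GetTaskSize(2) == 3; out-of-range split ids return 0.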

template <typename TaskType>
std::unique_ptr<TaskType> BatchInputTask<TaskType>::GetSplitTask(int split_id) {
  absl::call_once(once_,
                  [this]() { split_status_ = SplitBatches(&task_splits_); });
  if (!split_status_.ok()) {
    return nullptr;
  }
  if (split_id >= 0 && split_id < task_splits_.size()) {
    return std::move(task_splits_[split_id]);
  }
  return nullptr;
}

template <typename TaskType>
Status BatchInputTask<TaskType>::SplitBatches(
    std::vector<std::unique_ptr<TaskType>>* output_tasks) {
  return split_func_(&input_task_, open_batch_remaining_slot_,
                     batch_size_limit_, output_tasks);
}

}  // namespace serving
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_INPUT_TASK_H_