/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_INPUT_TASK_H_
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_INPUT_TASK_H_

#include <algorithm>
#include <atomic>
#include <functional>
#include <memory>
#include <utility>

#include "absl/base/call_once.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/concat_split_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/util/incremental_barrier.h"

namespace tensorflow {
namespace serving {

namespace internal {
template <typename TaskType>
class BatchInputTaskHandleTestAccess;

template <typename TaskType>
class BatchInputTaskTestAccess;
}  // namespace internal

template <typename TaskType>
class BatchInputTask;
// A RAII-style object that holds a ref-counted batch-input-task and
// represents one slice of that batch-input-task.
//
// To be handed out to callers of `BatchInputTask::ToTaskHandles` quickly
// (i.e., not necessarily waiting for the input split).
//
// `BatchInputTaskHandle::GetSplitTask` evaluates to the slice of the task.
template <typename TaskType>
class BatchInputTaskHandle : public BatchTask {
 public:
  BatchInputTaskHandle(
      std::shared_ptr<BatchInputTask<TaskType>> batch_input_task, int split_id,
      size_t task_size);

  // Should be called at most once. Returns the task slice on the first call
  // and nullptr on subsequent calls.
  std::unique_ptr<TaskType> GetSplitTask();

  // Returns the size of this task.
  size_t size() const override { return task_size_; }

 private:
  template <typename T>
  friend class internal::BatchInputTaskHandleTestAccess;

  int split_id() const { return split_id_; }

  std::shared_ptr<BatchInputTask<TaskType>> batch_input_task_;

  // The handle evaluates to the N-th slice of the original task, where
  // N is `split_id_`.
  const int split_id_;

  const size_t task_size_;

  std::atomic<bool> once_{false};
};
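
// For illustration, a minimal sketch of the call-once contract of
// `GetSplitTask` (`MyTask` is a hypothetical `BatchTask` subtype, and
// `handles` is assumed to come from `BatchInputTask::ToTaskHandles`):
//
//   std::vector<std::unique_ptr<BatchInputTaskHandle<MyTask>>> handles = ...;
//   std::unique_ptr<MyTask> slice = handles[0]->GetSplitTask();
//   // The handle is now consumed; subsequent calls return nullptr.
//   DCHECK(handles[0]->GetSplitTask() == nullptr);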

// BatchInputTask encapsulates an input (`input_task`) to be batched, together
// with the information needed to obtain task splits after it is enqueued, so
// as to support lazy splitting of a task.
//
// Splitting the input can reduce excessive padding for efficiency; lazy
// splitting moves the split work out of the critical path of enqueue and
// dequeue, and reduces contention.
//
// BatchInputTask is thread safe.
//
// Usage
//
// ... a deque with frequent enqueue and dequeue operations ...
// ... Note: a deque of Batch of BatchInputTaskHandle is used to form batches
//     at enqueue time (the split is lazy and happens at dequeue time).
// ... For use cases that form batches at dequeue time, a deque of
//     BatchInputTaskHandle can be used directly, "peeking" the metadata to
//     form a batch at that point.
// std::deque<std::unique_ptr<Batch<BatchInputTaskHandle<TaskType>>>> deque_
//     TF_GUARDED_BY(mu_);
//
// std::unique_ptr<TaskType> input_task;
//
// ... Enqueue path ...
//
// {
//   mutex_lock l(mu_);
//   std::shared_ptr<BatchInputTask<TaskType>> batch_input_task =
//       ConstructLazyBatchWithoutSplit(input_task);
//
//   std::vector<std::unique_ptr<BatchInputTaskHandle<TaskType>>> task_handles;
//   batch_input_task->ToTaskHandles(&task_handles);
//   for (int i = 0; i < task_handles.size(); ++i) {
//     EnqueueTaskHandleIntoDeque(std::move(task_handles[i]), deque_);
//   }
// }
//
// ... Dequeue path ...
// std::unique_ptr<Batch<BatchInputTaskHandle<TaskType>>> handles_to_schedule;
// {
//   mutex_lock l(mu_);
//   ... HasBatchToSchedule could be customized or specialized
//   ... (e.g., readiness depending on enqueue time)
//   if (HasBatchToSchedule(deque_)) {
//     handles_to_schedule = std::move(deque_.front());
//     deque_.pop_front();
//   }
// }
// ...... `mu_` is released ......
//
// std::vector<std::unique_ptr<BatchInputTaskHandle<TaskType>>>
//     tasks_in_batch = RemoveAllTasksFromBatch(handles_to_schedule);
//
// auto batch_to_schedule = std::make_unique<Batch<TaskType>>();
// for (int i = 0; i < tasks_in_batch.size(); i++) {
//   batch_to_schedule->AddTask(tasks_in_batch[i]->GetSplitTask());
// }
// batch_to_schedule->Close();
//
// `batch_to_schedule` is now ready to be scheduled.
template <typename TaskType>
class BatchInputTask
    : public std::enable_shared_from_this<BatchInputTask<TaskType>> {
 public:
  using BatchSplitFunc = std::function<Status(
      std::unique_ptr<TaskType>* input_task, int first_output_task_size,
      int input_batch_size_limit,
      std::vector<std::unique_ptr<TaskType>>* output_tasks)>;
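
  // A minimal sketch of a conforming `BatchSplitFunc` (assuming a
  // hypothetical `MyTask` with `size()` and a `SubTask(offset, size)` helper;
  // neither is defined in this header). The first chunk honors
  // `first_output_task_size` (the open batch's remaining slots, or zero when
  // there is no open batch); subsequent chunks use `input_batch_size_limit`:
  //
  //   auto split_func =
  //       [](std::unique_ptr<MyTask>* input_task, int first_output_task_size,
  //          int input_batch_size_limit,
  //          std::vector<std::unique_ptr<MyTask>>* output_tasks) -> Status {
  //     int offset = 0;
  //     int remaining = static_cast<int>((*input_task)->size());
  //     int chunk = first_output_task_size > 0 ? first_output_task_size
  //                                            : input_batch_size_limit;
  //     while (remaining > 0) {
  //       chunk = std::min(chunk, remaining);
  //       output_tasks->push_back((*input_task)->SubTask(offset, chunk));
  //       offset += chunk;
  //       remaining -= chunk;
  //       chunk = input_batch_size_limit;
  //     }
  //     return Status();
  //   };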

  // TODO(b/194294263):
  // Add a SplitMetadataFunc in the constructor, so that users of this class
  // can specify both how to split and how to compute the split metadata in a
  // consistent way.
  BatchInputTask(std::unique_ptr<TaskType> input_task,
                 int open_batch_remaining_slot, int batch_size_limit,
                 BatchSplitFunc split_func);

  // Outputs the task handles for the input task.
  // Each task handle represents a slice of the task after the input task is
  // split, and can evaluate to that slice.
  //
  // NOTE:
  // Each task handle in `output_task_handles` takes ownership of a reference
  // to this BatchInputTask.
  void ToTaskHandles(
      std::vector<std::unique_ptr<BatchInputTaskHandle<TaskType>>>*
          output_task_handles);

 private:
  friend class BatchInputTaskHandle<TaskType>;
  template <typename T>
  friend class internal::BatchInputTaskTestAccess;

  // The following methods expose the split metadata of this task. The
  // metadata is used to determine batch construction, so it is needed before
  // the split happens.

  // Returns the task size of the N-th batch, where N is `split_id`.
  int GetTaskSize(int split_id) const;

  // Task size of `input_task`.
  size_t size() const;

  // The number of batches the input spans.
  int num_batches() const;

  // The number of new batches this input adds.
  int num_new_batches() const;

  // The task size of the head (first) batch.
  int head_batch_task_size() const;

  // The task size of the tail (last) batch.
  int tail_batch_task_size() const;

  std::unique_ptr<TaskType> GetSplitTask(int split_id);

  Status SplitBatches(std::vector<std::unique_ptr<TaskType>>* output_tasks);

  std::unique_ptr<TaskType> input_task_;

  const int input_task_size_ = 0;
  const int open_batch_remaining_slot_;

  const int batch_size_limit_;

  const BatchSplitFunc split_func_;

  // The number of existing batches that this input appends to.
  // Should be either zero or one.
  const int num_batches_reused_ = 0;

  // The number of batches this input spans.
  int num_batches_ = 0;

  // The task size of the tail batch.
  int tail_batch_task_size_;

  // The task size of the head batch.
  int head_batch_task_size_;

  mutable absl::once_flag once_;

  std::vector<std::unique_ptr<TaskType>> task_splits_;
  Status split_status_;
};
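
// A minimal construction sketch (`MyTask` and `split_func` are the
// hypothetical task type and split function sketched above; the numeric
// arguments are illustrative). Note that the object must be owned by a
// std::shared_ptr, since `ToTaskHandles` calls `shared_from_this()`:
//
//   auto batch_input_task = std::make_shared<BatchInputTask<MyTask>>(
//       std::move(input_task), /*open_batch_remaining_slot=*/3,
//       /*batch_size_limit=*/4, split_func);
//
//   std::vector<std::unique_ptr<BatchInputTaskHandle<MyTask>>> handles;
//   batch_input_task->ToTaskHandles(&handles);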

//
// Implementation details. API readers may skip.
//

template <typename TaskType>
BatchInputTaskHandle<TaskType>::BatchInputTaskHandle(
    std::shared_ptr<BatchInputTask<TaskType>> batch_input_task, int split_id,
    size_t task_size)
    : batch_input_task_(std::move(batch_input_task)),
      split_id_(split_id),
      task_size_(task_size) {}

template <typename TaskType>
std::unique_ptr<TaskType> BatchInputTaskHandle<TaskType>::GetSplitTask() {
  if (once_.load(std::memory_order_acquire)) {
    return nullptr;
  }
  once_.store(true, std::memory_order_release);
  return batch_input_task_->GetSplitTask(split_id_);
}

template <typename TaskType>
BatchInputTask<TaskType>::BatchInputTask(
    std::unique_ptr<TaskType> input_task, int open_batch_remaining_slot,
    int batch_size_limit,
    std::function<Status(std::unique_ptr<TaskType>* input_task,
                         int first_output_task_size, int input_batch_size_limit,
                         std::vector<std::unique_ptr<TaskType>>* output_tasks)>
        split_func)
    : input_task_(std::move(input_task)),
      input_task_size_(input_task_->size()),
      open_batch_remaining_slot_(open_batch_remaining_slot),
      batch_size_limit_(batch_size_limit),
      split_func_(split_func),
      num_batches_reused_((open_batch_remaining_slot_ > 0) ? 1 : 0) {
  // The total task size counted from the start of the current open batch,
  // after this task is enqueued.
  const int task_size_from_open_batch =
      (open_batch_remaining_slot_ > 0)
          ? (input_task_size_ + batch_size_limit_ - open_batch_remaining_slot_)
          : input_task_size_;

  num_batches_ =
      (task_size_from_open_batch + batch_size_limit_ - 1) / batch_size_limit_;

  if (open_batch_remaining_slot_ == 0) {
    head_batch_task_size_ = std::min(input_task_size_, batch_size_limit_);
  } else {
    head_batch_task_size_ = (input_task_size_ >= open_batch_remaining_slot_)
                                ? open_batch_remaining_slot_
                                : input_task_size_;
  }
  if (input_task_size_ <= open_batch_remaining_slot_) {
    tail_batch_task_size_ = input_task_size_;
  } else {
    tail_batch_task_size_ = task_size_from_open_batch % batch_size_limit_;
    if (tail_batch_task_size_ == 0) {
      tail_batch_task_size_ = batch_size_limit_;
    }
  }
}
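
// For intuition, a worked example of the bookkeeping above (the values are
// illustrative): with input_task_size_ = 10, open_batch_remaining_slot_ = 3,
// and batch_size_limit_ = 4:
//   task_size_from_open_batch = 10 + 4 - 3 = 11
//   num_batches_              = ceil(11 / 4) = 3
//   head_batch_task_size_     = 3  (fills the open batch's remaining slots)
//   tail_batch_task_size_     = 11 % 4 = 3
// The input is split 3 + 4 + 3 = 10, i.e., one reused batch
// (num_batches_reused_ = 1) and two new batches (num_new_batches() = 2).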

template <typename TaskType>
size_t BatchInputTask<TaskType>::size() const {
  return input_task_size_;
}

template <typename TaskType>
int BatchInputTask<TaskType>::num_batches() const {
  return num_batches_;
}

template <typename TaskType>
int BatchInputTask<TaskType>::num_new_batches() const {
  return num_batches_ - num_batches_reused_;
}

template <typename TaskType>
int BatchInputTask<TaskType>::head_batch_task_size() const {
  return head_batch_task_size_;
}

template <typename TaskType>
int BatchInputTask<TaskType>::tail_batch_task_size() const {
  return tail_batch_task_size_;
}

template <typename TaskType>
void BatchInputTask<TaskType>::ToTaskHandles(
    std::vector<std::unique_ptr<BatchInputTaskHandle<TaskType>>>*
        task_handles) {
  task_handles->resize(num_batches_);
  for (int i = 0; i < num_batches_; i++) {
    (*task_handles)[i] = std::make_unique<BatchInputTaskHandle<TaskType>>(
        this->shared_from_this(), i, GetTaskSize(i));
  }
}

template <typename TaskType>
int BatchInputTask<TaskType>::GetTaskSize(int split_id) const {
  if (split_id < 0 || split_id >= num_batches_) {
    return 0;
  }
  if (split_id == 0) {
    return head_batch_task_size_;
  }
  if (split_id == num_batches_ - 1) {
    return tail_batch_task_size_;
  }

  return batch_size_limit_;
}

template <typename TaskType>
std::unique_ptr<TaskType> BatchInputTask<TaskType>::GetSplitTask(int split_id) {
  absl::call_once(once_,
                  [this]() { split_status_ = SplitBatches(&task_splits_); });
  if (!split_status_.ok()) {
    return nullptr;
  }
  if (split_id >= 0 && split_id < task_splits_.size()) {
    return std::move(task_splits_[split_id]);
  }
  return nullptr;
}

template <typename TaskType>
Status BatchInputTask<TaskType>::SplitBatches(
    std::vector<std::unique_ptr<TaskType>>* output_tasks) {
  return split_func_(&input_task_, open_batch_remaining_slot_,
                     batch_size_limit_, output_tasks);
}

}  // namespace serving
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_INPUT_TASK_H_