/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_RESOURCE_BASE_H_
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_RESOURCE_BASE_H_

#include <map>

#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/cost_measurement_registry.h"
#include "tensorflow/core/common_runtime/request_cost.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/threadsafe_status.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace tensorflow {
namespace serving {

// Base class for a resource that encapsulates the state and logic for batching
// tensors.
class BatchResourceBase : public ResourceBase {
 public:
  // Given a BatchTask (from one op invocation) with 'num_outputs' == M that is
  // split into N sub-tasks, TensorMatrix is an N x M matrix.
  // Namely, TensorMatrix[i][j] is the i-th split of the j-th output;
  // concatenating the tensors in the j-th column (over all splits i) yields
  // the j-th output tensor.
  typedef std::vector<std::vector<Tensor>> TensorMatrix;
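
  // Example (illustrative, hypothetical sizes): if one op invocation has
  // M = 2 outputs and its input is split into N = 3 sub-tasks, the resulting
  // TensorMatrix has shape 3 x 2, and output j is reassembled by concatenating
  // TensorMatrix[0][j], TensorMatrix[1][j], and TensorMatrix[2][j] along
  // dimension 0.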

  // Ingests data from one invocation of the batch op. The data is enqueued to
  // be combined with others into a batch, asynchronously.
  Status RegisterInput(int64_t guid, OpKernelContext* context,
                       const string& batcher_queue_name,
                       AsyncOpKernel::DoneCallback done_callback);
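
  // Usage sketch (illustrative only; the kernel, resource type, and queue name
  // below are hypothetical): an AsyncOpKernel typically forwards its inputs
  // like
  //
  //   void MyBatchKernel::ComputeAsync(OpKernelContext* c, DoneCallback done) {
  //     MyBatchResource* resource = ...;  // looked up from the ResourceMgr
  //     OP_REQUIRES_OK_ASYNC(
  //         c, resource->RegisterInput(guid, c, "my_batcher_queue", done),
  //         done);
  //     // `done` runs later, once the batch containing this input completes.
  //   }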

 public:
  // One task to be batched, corresponding to a `slice` of input from one
  // batch-op invocation.
  //
  // Given input from one batch-op invocation, a `slice` of this input is
  // obtained by:
  // 1) splitting each Tensor in `BatchTask::inputs` along the 0th dimension,
  //    and
  // 2) recording in 'split_index' which split along the 0th dimension this
  //    task corresponds to.
  //
  // Note that the whole input from one batch-op invocation is itself a valid,
  // specialized `slice`.
  struct BatchTask : public tensorflow::serving::BatchTask {
    // A unique ID to identify this invocation of Batch.
    int64 guid;

    Context propagated_context;

    std::vector<Tensor> inputs;
    std::vector<Tensor> captured_inputs;
    OpKernelContext* context;
    AsyncOpKernel::DoneCallback done_callback;

    // The index of this split, along the 0th dimension of the input from the
    // op invocation.
    int split_index = 0;

    // Two-dimensional tensor matrix whose ownership is shared by:
    // 1) each split task (each fills one row of this matrix), and
    // 2) the callback that merges the outputs of the individual splits for an
    //    op invocation, after all splits complete.
    std::shared_ptr<TensorMatrix> output;

    // 'status' records the error (which could come from any split) if at least
    // one split returns an error, and is OK otherwise.
    // Ownership is shared by the individual splits and the callback.
    std::shared_ptr<ThreadSafeStatus> status;

    bool is_partial = false;

    uint64 start_time;

    size_t size() const override { return inputs[0].shape().dim_size(0); }

    // Creates a split task from this one. The caller needs to set up the
    // inputs of the new task.
    std::unique_ptr<BatchTask> CreateSplitTask(
        int split_index, AsyncOpKernel::DoneCallback done_callback);

    // RequestCost collects the cost and must outlive the batch processing.
    //
    // For example, to collect the cost of RPC processing, `request_cost` is
    // owned by the RPC handler and points to the RequestCost of the RPC that
    // provides the inputs to this BatchTask.
    //
    // After batch processing, the request cost will be incremented with this
    // task's processing costs.
    RequestCost* request_cost = nullptr;

   protected:
    // Factory for creating a task of the (possibly derived) BatchTask type;
    // see the sketch after this struct for a typical override.
    virtual std::unique_ptr<BatchTask> CreateDerivedTask() {
      return std::make_unique<BatchTask>();
    }
  };
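
  // A subclass that carries extra per-task state would typically override
  // `CreateDerivedTask` so that split tasks are created with the derived type.
  // Illustrative sketch (the type below is hypothetical):
  //
  //   struct MyBatchTask : public BatchResourceBase::BatchTask {
  //     // extra fields ...
  //    protected:
  //     std::unique_ptr<BatchTask> CreateDerivedTask() override {
  //       return std::make_unique<MyBatchTask>();
  //     }
  //   };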

  // A T suffix is appended to make these type aliases different from those in
  // the tensorflow::serving namespace, because some compiler versions complain
  // about the meaning of the symbols changing.
  using BatcherT = SharedBatchScheduler<BatchResourceBase::BatchTask>;
  using AdaptiveBatcherT =
      AdaptiveSharedBatchScheduler<BatchResourceBase::BatchTask>;
  using BatcherQueueT = BatchScheduler<BatchResourceBase::BatchTask>;
  using BatchT = Batch<BatchResourceBase::BatchTask>;

  BatchResourceBase(bool has_process_batch_function,
                    std::shared_ptr<BatcherT> batcher,
                    const BatcherT::QueueOptions& batcher_queue_options,
                    std::vector<int32> allowed_batch_sizes)
      : has_process_batch_function_(has_process_batch_function),
        batcher_(std::move(batcher)),
        batcher_queue_options_(batcher_queue_options),
        allowed_batch_sizes_(std::move(allowed_batch_sizes)) {
    allowed_batch_sizes_str_ = absl::StrJoin(allowed_batch_sizes_, ",");
  }

  BatchResourceBase(bool has_process_batch_function,
                    std::shared_ptr<AdaptiveBatcherT> batcher,
                    const AdaptiveBatcherT::QueueOptions& batcher_queue_options,
                    std::vector<int32> allowed_batch_sizes)
      : has_process_batch_function_(has_process_batch_function),
        adaptive_batcher_(std::move(batcher)),
        adaptive_batcher_queue_options_(batcher_queue_options),
        allowed_batch_sizes_(std::move(allowed_batch_sizes)) {}

  static BatcherT::QueueOptions GetBatcherQueueOptions(
      int32_t num_batch_threads, int32_t max_batch_size,
      int32_t batch_timeout_micros, int32_t max_enqueued_batches,
      const std::vector<int32>& allowed_batch_sizes,
      bool enable_large_batch_splitting);

  static AdaptiveBatcherT::QueueOptions GetAdaptiveBatcherQueueOptions(
      int32_t max_batch_size, int32_t batch_timeout_micros,
      int32_t max_enqueued_batches, bool enable_large_batch_splitting,
      const std::vector<int32>& allowed_batch_sizes);
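
  // Illustrative construction sketch (the subclass and parameter names are
  // hypothetical): a subclass typically creates a shared scheduler once and
  // passes it in together with queue options built by the helper above, e.g.
  //
  //   std::shared_ptr<BatcherT> batcher;
  //   TF_RETURN_IF_ERROR(BatcherT::Create(BatcherT::Options(), &batcher));
  //   auto* resource = new MyBatchResource(
  //       /*has_process_batch_function=*/true, batcher,
  //       GetBatcherQueueOptions(num_batch_threads, max_batch_size,
  //                              batch_timeout_micros, max_enqueued_batches,
  //                              allowed_batch_sizes,
  //                              enable_large_batch_splitting),
  //       allowed_batch_sizes);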

  // Splits the input of 'input_task_ptr' along the 0th dimension into a list
  // of 'output_tasks'.
  // Task sizes are determined by
  // 1) open_batch_remaining_slot,
  // 2) max_batch_size, and
  // 3) the size of the input task,
  // such that
  // 1) the task sizes add up to the size of the input task, and
  // 2) from left to right, the task sizes look like
  //    [open_batch_remaining_slot, max_batch_size, max_batch_size, ...,
  //     size-of-input-task - sum-of-previous-elements].
  //
  // REQUIRES:
  // The caller must ensure that the size of the input task is greater than
  // open_batch_remaining_slot.
  static Status SplitInputTask(
      std::unique_ptr<BatchTask>* input_task_ptr, int open_batch_remaining_slot,
      int max_batch_size,
      std::vector<std::unique_ptr<BatchTask>>* output_tasks);
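
  // Worked example (hypothetical numbers): for an input task of size 11 with
  // open_batch_remaining_slot = 2 and max_batch_size = 4, the resulting task
  // sizes are [2, 4, 4, 1]: the first sub-task fills the open batch,
  // subsequent sub-tasks take max_batch_size each, and the last holds the
  // remainder.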

 private:
  // Implementation of calling the process batch function.
  virtual void ProcessFuncBatchImpl(
      const BatchResourceBase::BatchTask& last_task,
      absl::Span<const Tensor> inputs, std::vector<Tensor>* combined_outputs,
      std::function<void(const Status&)> done) const = 0;

  // Factory method for creating a BatchTask, overridable by subclasses.
  virtual Status CreateBatchTask(
      OpKernelContext* context,
      std::unique_ptr<BatchResourceBase::BatchTask>* output) const;

  // Validates that it's legal to combine the tasks in 'batch' into a batch.
  // Assumes the batch is non-empty.
  static Status ValidateBatch(const BatchT& batch);

  // Returns the smallest entry in 'allowed_batch_sizes_' that is greater than
  // or equal to 'batch_size'. If 'allowed_batch_sizes_' is empty, simply
  // returns 'batch_size'.
  int RoundToLowestAllowedBatchSize(int batch_size) const;
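  //
  // For example (illustrative values): with allowed_batch_sizes_ = {2, 4, 8},
  // a batch_size of 5 is rounded up to 8, a batch_size of 4 stays at 4, and a
  // batch_size of 1 is rounded up to 2.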

  Status ConcatInputTensors(const BatchT& batch, OpKernelContext* context,
                            std::vector<Tensor>* concatenated_tensors) const;

  Status SplitOutputTensors(const std::vector<Tensor>& combined_outputs,
                            BatchT* batch) const;

  void ProcessFuncBatch(std::unique_ptr<BatchT> batch) const;

  // Processes a batch of one or more BatchTask entries.
  void ProcessBatch(std::unique_ptr<BatchT> batch) const;

  // Emits an index tensor, which the Unbatch op will use to un-concatenate
  // the tensor and attribute the pieces to the right batch keys. The index
  // tensor contains, for each input: [batch_key, start_offset, end_offset]
  // where start_offset and end_offset represent the range of entries in the
  // concatenated tensors that belong to that input.
  //
  // Emits the result to the output at 'output_index' using 'context'.
  static Status EmitIndexTensor(OpKernelContext* context, const BatchT& batch,
                                int output_index);
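  //
  // For example (hypothetical values): for a batch holding two inputs of sizes
  // 3 and 5, the index tensor has two rows, roughly
  //   [key_of_input_0, 0, 3]
  //   [key_of_input_1, 3, 8]
  // so the Unbatch op can slice rows [0, 3) and [3, 8) out of the concatenated
  // result.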

  // Looks up the batcher queue for 'queue_name'. If it didn't previously
  // exist, creates it.
  Status LookupOrCreateBatcherQueue(const string& queue_name,
                                    BatcherQueueT** queue);

  // Splits the batch cost among the individual tasks.
  //
  // Inputs:
  // 1) batch_cost_measurement, which provides the total cost and the cost
  //    type;
  // 2) processed_size, which is the batch size plus the padding amount;
  // 3) batch, which provides the batch size.
  //
  // Outputs:
  // The request_cost of each batch task is updated. If the batch cost is
  // non-zero, this function splits it using two approaches, so two costs are
  // recorded per task: one approach splits the cost proportionally to each
  // task's size with the padding sharing none of the cost, while the other
  // splits the cost proportionally to each task's or padding's size.
  void SplitBatchCost(CostMeasurement* batch_cost_measurement,
                      const int64_t processed_size, BatchT& batch) const;
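  //
  // Worked example (hypothetical numbers): suppose the measured batch cost is
  // 8ms, the batch size is 12 (two tasks of sizes 4 and 8), and processed_size
  // is 16 (i.e. 4 units of padding). Under the first approach the task of
  // size 4 is charged 8ms * 4/12 ≈ 2.67ms; under the second it is charged
  // 8ms * 4/16 = 2ms, and the padding's share (8ms * 4/16 = 2ms) is not
  // charged to any task.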

  // True if the user specified a batch processing function for this resource.
  const bool has_process_batch_function_;
  // A batch scheduler, and options for creating queues.
  std::shared_ptr<BatcherT> batcher_;
  BatcherT::QueueOptions batcher_queue_options_;

  // An adaptive batch scheduler, and options for creating queues.
  std::shared_ptr<AdaptiveBatcherT> adaptive_batcher_;
  AdaptiveBatcherT::QueueOptions adaptive_batcher_queue_options_;

  // A collection of batcher queues, keyed on queue name.
  // TODO(olston): Garbage-collect unused queues (perhaps simply remove empty
  // ones (with a time delay?); it's okay if they get recreated later).
  mutable mutex batcher_queues_mu_;
  std::map<string, std::unique_ptr<BatcherQueueT>> batcher_queues_
      TF_GUARDED_BY(batcher_queues_mu_);

  std::vector<int32> allowed_batch_sizes_;
  // A concatenated string of <allowed_batch_sizes_>, separated by ",". This is
  // used to record the batching parameter.
  string allowed_batch_sizes_str_;
};

}  // namespace serving
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_RESOURCE_BASE_H_