/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_RESOURCE_BASE_H_
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_RESOURCE_BASE_H_

#include <map>

#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/threadsafe_status.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/thread_annotations.h"
namespace tensorflow {
namespace serving {

// Base class for a resource that encapsulates the state and logic for
// batching tensors.
class BatchResourceBase : public ResourceBase {
 public:
  // Given a BatchTask (from one op invocation) with 'num_outputs' == M that is
  // split into N subtasks, TensorMatrix is an N x M matrix. Namely,
  // TensorMatrix[i][j] is the i-th split of the j-th output; concatenating the
  // tensors in column j along their 0th dimension reconstructs the j-th
  // output tensor.
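  // For example (an illustration of the layout, not additional API): with
  // N = 3 splits and M = 2 outputs, output j is recovered as
  //   Concat(TensorMatrix[0][j], TensorMatrix[1][j], TensorMatrix[2][j]).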
  typedef std::vector<std::vector<Tensor>> TensorMatrix;

  // Ingests data from one invocation of the batch op. The data is enqueued to
  // be combined with others into a batch, asynchronously.
  Status RegisterInput(int64 guid, OpKernelContext* context,
                       const string& batcher_queue_name,
                       AsyncOpKernel::DoneCallback done_callback);
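
  // A typical call site, from the ComputeAsync() of a batching AsyncOpKernel
  // (a hedged sketch; looking the resource up via the ResourceMgr is assumed,
  // and `queue_name` is illustrative):
  //
  //   void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
  //     BatchResourceBase* br = /* LookupOrCreate via c->resource_manager() */;
  //     OP_REQUIRES_OK_ASYNC(
  //         c, br->RegisterInput(random::New64(), c, queue_name, done), done);
  //   }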

 public:
  // One task to be batched; it corresponds to a `slice` of the input from one
  // batch-op invocation.
  //
  // Given the input from one batch-op invocation, a `slice` of this input is
  // obtained as follows:
  // 1) Each Tensor in `BatchTask::inputs` is split along the 0th dimension.
  // 2) 'split_index' records the position of this slice along the 0th
  //    dimension.
  //
  // Note that the full input from one batch-op invocation is itself valid and
  // considered a special case of a `slice`.
  struct BatchTask : public tensorflow::serving::BatchTask {
    // A unique ID to identify this invocation of Batch.
    int64 guid;

    // The context propagated from the caller's thread.
    Context propagated_context;

    std::vector<Tensor> inputs;
    std::vector<Tensor> captured_inputs;
    OpKernelContext* context;
    AsyncOpKernel::DoneCallback done_callback;

    // The index of this split, along the 0th dimension of the input from the
    // op invocation.
    int split_index = 0;

    // Two-dimensional tensor matrix whose ownership is shared by:
    // 1) each split task (which fills one row of the matrix), and
    // 2) the callback that merges the outputs of the individual splits for an
    //    op invocation, after all splits complete.
    std::shared_ptr<TensorMatrix> output;

    // 'status' records an error (which could come from any split) if at least
    // one split returns an error, and OK otherwise.
    // Ownership is shared by the individual splits and the callback.
    std::shared_ptr<ThreadSafeStatus> status;

    // True if this task is a split of a larger input task.
    bool is_partial = false;

    // The time at which the task was created; used to record batching latency.
    uint64 start_time;

    size_t size() const override { return inputs[0].shape().dim_size(0); }

    // Creates a split task from this one. The caller needs to set up the
    // inputs of the new task.
    std::unique_ptr<BatchTask> CreateSplitTask(
        int split_index, AsyncOpKernel::DoneCallback done_callback);

   protected:
    virtual std::unique_ptr<BatchTask> CreateDerivedTask() {
      return std::make_unique<BatchTask>();
    }
  };
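
  // A subclass may carry extra per-task state by deriving from `BatchTask`
  // and overriding CreateDerivedTask(), so that CreateSplitTask() produces
  // tasks of the derived type (an illustrative sketch; `MyTask` and its
  // `request_id` field are hypothetical):
  //
  //   struct MyTask : BatchResourceBase::BatchTask {
  //     int64 request_id;  // extra state to propagate into each split
  //
  //    protected:
  //     std::unique_ptr<BatchTask> CreateDerivedTask() override {
  //       auto task = std::make_unique<MyTask>();
  //       task->request_id = request_id;  // copy the derived-only state
  //       return task;
  //     }
  //   };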

  // A 'T' suffix is appended to make these type aliases distinct from those
  // in the tensorflow::serving namespace, because some compiler versions
  // complain about the changing meaning of the symbols.
  using BatcherT = SharedBatchScheduler<BatchResourceBase::BatchTask>;
  using AdaptiveBatcherT =
      AdaptiveSharedBatchScheduler<BatchResourceBase::BatchTask>;
  using BatcherQueueT = BatchScheduler<BatchResourceBase::BatchTask>;
  using BatchT = Batch<BatchResourceBase::BatchTask>;

  BatchResourceBase(bool has_process_batch_function,
                    std::shared_ptr<BatcherT> batcher,
                    const BatcherT::QueueOptions& batcher_queue_options,
                    std::vector<int32> allowed_batch_sizes)
      : has_process_batch_function_(has_process_batch_function),
        batcher_(std::move(batcher)),
        batcher_queue_options_(batcher_queue_options),
        allowed_batch_sizes_(std::move(allowed_batch_sizes)) {
    allowed_batch_sizes_str_ = absl::StrJoin(allowed_batch_sizes_, ",");
  }
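
  // Example wiring (a hedged sketch; `MyBatchResource` is a hypothetical
  // subclass, and the option values are illustrative). The queue options
  // typically come from GetBatcherQueueOptions(), declared below:
  //
  //   std::shared_ptr<BatcherT> batcher;
  //   TF_RETURN_IF_ERROR(BatcherT::Create(BatcherT::Options(), &batcher));
  //   MyBatchResource* resource = new MyBatchResource(
  //       /*has_process_batch_function=*/true, batcher,
  //       GetBatcherQueueOptions(
  //           /*num_batch_threads=*/4, /*max_batch_size=*/32,
  //           /*batch_timeout_micros=*/1000, /*max_enqueued_batches=*/100,
  //           /*allowed_batch_sizes=*/{8, 16, 32},
  //           /*enable_large_batch_splitting=*/true),
  //       /*allowed_batch_sizes=*/{8, 16, 32});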

  BatchResourceBase(bool has_process_batch_function,
                    std::shared_ptr<AdaptiveBatcherT> batcher,
                    const AdaptiveBatcherT::QueueOptions& batcher_queue_options,
                    std::vector<int32> allowed_batch_sizes)
      : has_process_batch_function_(has_process_batch_function),
        adaptive_batcher_(std::move(batcher)),
        adaptive_batcher_queue_options_(batcher_queue_options),
        allowed_batch_sizes_(std::move(allowed_batch_sizes)) {}

  static BatcherT::QueueOptions GetBatcherQueueOptions(
      int32 num_batch_threads, int32 max_batch_size, int32 batch_timeout_micros,
      int32 max_enqueued_batches, const std::vector<int32>& allowed_batch_sizes,
      bool enable_large_batch_splitting);

  static AdaptiveBatcherT::QueueOptions GetAdaptiveBatcherQueueOptions(
      int32 max_batch_size, int32 batch_timeout_micros,
      int32 max_enqueued_batches, bool enable_large_batch_splitting,
      const std::vector<int32>& allowed_batch_sizes);

 private:
  // Implements the call to the user-provided batch processing function.
  virtual void ProcessFuncBatchImpl(
      const BatchResourceBase::BatchTask& last_task,
      absl::Span<const Tensor> inputs, std::vector<Tensor>* combined_outputs,
      std::function<void(const Status&)> done) const = 0;

  // Factory method for creating a BatchTask, overridable by subclasses.
  virtual Status CreateBatchTask(
      OpKernelContext* context,
      std::unique_ptr<BatchResourceBase::BatchTask>* output) const;

  // Validates that it's legal to combine the tasks in 'batch' into a batch.
  // Assumes the batch is non-empty.
  static Status ValidateBatch(const BatchT& batch);

  // Returns the smallest entry in 'allowed_batch_sizes_' that is greater than
  // or equal to 'batch_size'. If 'allowed_batch_sizes_' is empty, simply
  // returns 'batch_size'.
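  // For example, with allowed_batch_sizes_ = {2, 4, 8}, a batch_size of 3
  // rounds up to 4 and a batch_size of 8 stays 8; with an empty list, 3 is
  // returned unchanged.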
  int RoundToLowestAllowedBatchSize(int batch_size) const;

  Status ConcatInputTensors(const BatchT& batch, OpKernelContext* context,
                            std::vector<Tensor>* concatenated_tensors) const;

  // Splits the 'input' of 'input_task_ptr' along the 0th dimension into a
  // list of 'output_tasks'.
  // Task sizes are determined by
  // 1) open_batch_remaining_slot
  // 2) max_batch_size
  // 3) size-of-input-task
  // in such a way that
  // 1) the task sizes add up to `size-of-input-task`, and
  // 2) from left to right, the task sizes are
  //    [open_batch_remaining_slot, max_batch_size, max_batch_size, ...,
  //    `size-of-input-task` - `sum-of-previous-elements`].
  //
  // REQUIRES:
  // The caller should make sure size-of-input-task is greater than
  // open_batch_remaining_slot.
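  //
  // For example, with open_batch_remaining_slot = 2, max_batch_size = 4, and
  // an input task of size 9, the output task sizes are [2, 4, 3].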
  static Status SplitInputTask(
      std::unique_ptr<BatchTask>* input_task_ptr, int open_batch_remaining_slot,
      int max_batch_size,
      std::vector<std::unique_ptr<BatchTask>>* output_tasks);

  Status SplitOutputTensors(const std::vector<Tensor>& combined_outputs,
                            BatchT* batch) const;

  // Processes a batch by invoking the user-provided batch function (via
  // ProcessFuncBatchImpl).
  void ProcessFuncBatch(std::unique_ptr<BatchT> batch) const;

  // Processes a batch of one or more BatchTask entries.
  void ProcessBatch(std::unique_ptr<BatchT> batch) const;

  // Emits an index tensor, which the Unbatch op will use to un-concatenate
  // the tensor and attribute the pieces to the right batch keys. The index
  // tensor contains, for each input: [batch_key, start_offset, end_offset],
  // where start_offset and end_offset delimit the range of entries in the
  // concatenated tensors that belong to that input.
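  // For example, a batch formed from two inputs with 3 and 2 rows produces an
  // index tensor along the lines of [[key_a, 0, 3], [key_b, 3, 5]], where
  // each key identifies the originating invocation (an illustration of the
  // layout described above; the exact key values are an implementation
  // detail).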
  //
  // Emits the result to the output at 'output_index' using 'context'.
  static Status EmitIndexTensor(OpKernelContext* context, const BatchT& batch,
                                int output_index);

  // Looks up the batcher queue for 'queue_name'. If it didn't previously
  // exist, creates it.
  Status LookupOrCreateBatcherQueue(const string& queue_name,
                                    BatcherQueueT** queue);

  // True if the user specified a batch processing function for this resource.
  const bool has_process_batch_function_;
  // A batch scheduler, and options for creating queues.
  std::shared_ptr<BatcherT> batcher_;
  BatcherT::QueueOptions batcher_queue_options_;

  // An adaptive batch scheduler, and options for creating queues.
  std::shared_ptr<AdaptiveBatcherT> adaptive_batcher_;
  AdaptiveBatcherT::QueueOptions adaptive_batcher_queue_options_;

  // A collection of batcher queues, keyed on queue name.
  // TODO(olston): Garbage-collect unused queues (perhaps simply remove empty
  // ones (with a time delay?); it's okay if they get recreated later).
  mutable mutex batcher_queues_mu_;
  std::map<string, std::unique_ptr<BatcherQueueT>> batcher_queues_
      TF_GUARDED_BY(batcher_queues_mu_);

  std::vector<int32> allowed_batch_sizes_;
  // The entries of <allowed_batch_sizes_> joined by ",". This is used to
  // record the batching parameter.
  string allowed_batch_sizes_str_;
};

}  // namespace serving
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_RESOURCE_BASE_H_