/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_RESOURCE_BASE_H_
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_RESOURCE_BASE_H_

#include <map>

#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/threadsafe_status.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace tensorflow {
namespace serving {

// Base class for a resource that encapsulates the state and logic for
// batching tensors.
class BatchResourceBase : public ResourceBase {
 public:
  // Given a BatchTask (from one op invocation) with 'num_outputs' == M that
  // is split into N subtasks, TensorMatrix is an N x M matrix. Namely,
  // TensorMatrix[i][j] holds the i-th split of the j-th output; concatenating
  // the tensors in column j (over all N splits) yields the j-th output
  // tensor.
  typedef std::vector<std::vector<Tensor>> TensorMatrix;

  // Ingests data from one invocation of the batch op. The data is enqueued to
  // be combined with others into a batch, asynchronously.
  Status RegisterInput(int64 guid, OpKernelContext* context,
                       const string& batcher_queue_name,
                       AsyncOpKernel::DoneCallback done_callback);

 public:
  // One task to be batched; corresponds to a `slice` of input from one
  // batch-op invocation.
  //
  // Given input from one batch-op invocation, a `slice` of this input is
  // obtained by:
  // 1) splitting each Tensor in `BatchTask::inputs` along the 0th dimension,
  //    and
  // 2) recording in 'split_index' where the slice falls along the 0th
  //    dimension.
  //
  // Note that the whole input from one batch-op invocation is itself a valid,
  // specialized `slice`.
  struct BatchTask : public tensorflow::serving::BatchTask {
    // A unique ID to identify this invocation of Batch.
    int64 guid;

    Context propagated_context;

    std::vector<Tensor> inputs;
    std::vector<Tensor> captured_inputs;
    OpKernelContext* context;
    AsyncOpKernel::DoneCallback done_callback;
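
    // Illustrative note (hypothetical sizes): if the inputs of one invocation
    // have a 0th dimension of 6 and the scheduler splits them into slices of
    // sizes [2, 4], each slice becomes its own BatchTask, and the fields
    // below record where this task's slice sits within the original
    // invocation.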

    // The index of this split, along the 0th dimension of the input from the
    // op invocation.
    int split_index = 0;

    // Two-dimensional tensor matrix, with ownership shared by:
    // 1) each split of the task (which fills one row of the matrix), and
    // 2) the callback that runs to merge the outputs of the individual splits
    //    for an op invocation, after all splits complete.
    std::shared_ptr<TensorMatrix> output;

    // 'status' records the error (which could come from any split) if at
    // least one split returns an error, and OK otherwise.
    // Ownership is shared by the individual splits and the callback.
    std::shared_ptr<ThreadSafeStatus> status;

    // True if this task is a slice produced by splitting a larger input.
    bool is_partial = false;

    uint64 start_time;

    size_t size() const override { return inputs[0].shape().dim_size(0); }

    // Creates a split task from this one. The caller needs to set up the
    // inputs of the new task.
    std::unique_ptr<BatchTask> CreateSplitTask(
        int split_index, AsyncOpKernel::DoneCallback done_callback);

   protected:
    virtual std::unique_ptr<BatchTask> CreateDerivedTask() {
      return std::make_unique<BatchTask>();
    }
  };

  // Appending a T suffix to make these type aliases distinct from those in
  // the tensorflow::serving namespace, because some compiler versions
  // complain about changing the meaning of the symbols.
  using BatcherT = SharedBatchScheduler<BatchResourceBase::BatchTask>;
  using AdaptiveBatcherT =
      AdaptiveSharedBatchScheduler<BatchResourceBase::BatchTask>;
  using BatcherQueueT = BatchScheduler<BatchResourceBase::BatchTask>;
  using BatchT = Batch<BatchResourceBase::BatchTask>;

  BatchResourceBase(bool has_process_batch_function,
                    std::shared_ptr<BatcherT> batcher,
                    const BatcherT::QueueOptions& batcher_queue_options,
                    std::vector<int32> allowed_batch_sizes)
      : has_process_batch_function_(has_process_batch_function),
        batcher_(std::move(batcher)),
        batcher_queue_options_(batcher_queue_options),
        allowed_batch_sizes_(std::move(allowed_batch_sizes)) {
    allowed_batch_sizes_str_ = absl::StrJoin(allowed_batch_sizes_, ",");
  }

  BatchResourceBase(bool has_process_batch_function,
                    std::shared_ptr<AdaptiveBatcherT> batcher,
                    const AdaptiveBatcherT::QueueOptions& batcher_queue_options,
                    std::vector<int32> allowed_batch_sizes)
      : has_process_batch_function_(has_process_batch_function),
        adaptive_batcher_(std::move(batcher)),
        adaptive_batcher_queue_options_(batcher_queue_options),
        allowed_batch_sizes_(std::move(allowed_batch_sizes)) {}

  static BatcherT::QueueOptions GetBatcherQueueOptions(
      int32 num_batch_threads, int32 max_batch_size,
      int32 batch_timeout_micros, int32 max_enqueued_batches,
      const std::vector<int32>& allowed_batch_sizes,
      bool enable_large_batch_splitting);

  static AdaptiveBatcherT::QueueOptions GetAdaptiveBatcherQueueOptions(
      int32 max_batch_size, int32 batch_timeout_micros,
      int32 max_enqueued_batches, bool enable_large_batch_splitting,
      const std::vector<int32>& allowed_batch_sizes);
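
  // Illustrative construction sketch (not part of this interface;
  // `MyBatchResource` and the attribute variables are hypothetical). A
  // concrete kernel typically creates the shared scheduler once and passes
  // queue options derived from the op's attributes:
  //
  //   std::shared_ptr<BatcherT> batcher;
  //   TF_RETURN_IF_ERROR(BatcherT::Create(scheduler_options, &batcher));
  //   auto* resource = new MyBatchResource(
  //       /*has_process_batch_function=*/true, std::move(batcher),
  //       GetBatcherQueueOptions(num_batch_threads, max_batch_size,
  //                              batch_timeout_micros, max_enqueued_batches,
  //                              allowed_batch_sizes,
  //                              enable_large_batch_splitting),
  //       allowed_batch_sizes);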

 private:
  // Implementation of calling the process-batch function.
  virtual void ProcessFuncBatchImpl(
      const BatchResourceBase::BatchTask& last_task,
      absl::Span<const Tensor> inputs, std::vector<Tensor>* combined_outputs,
      std::function<void(const Status&)> done) const = 0;

  // Factory method for creating a BatchTask, overridable by subclasses.
  virtual Status CreateBatchTask(
      OpKernelContext* context,
      std::unique_ptr<BatchResourceBase::BatchTask>* output) const;

  // Validates that it's legal to combine the tasks in 'batch' into a batch.
  // Assumes the batch is non-empty.
  static Status ValidateBatch(const BatchT& batch);

  // Returns the smallest entry in 'allowed_batch_sizes_' that is greater than
  // or equal to 'batch_size'. If 'allowed_batch_sizes_' is empty, simply
  // returns 'batch_size'. For example, with allowed_batch_sizes_ = {2, 4, 8},
  // a batch_size of 5 rounds up to 8.
  int RoundToLowestAllowedBatchSize(int batch_size) const;

  Status ConcatInputTensors(const BatchT& batch, OpKernelContext* context,
                            std::vector<Tensor>* concatenated_tensors) const;

  // Splits the 'input' of 'input_task_ptr' along the 0th dimension into a
  // list of 'output_tasks'.
  // Task sizes are determined by
  // 1) open_batch_remaining_slot
  // 2) max_batch_size
  // 3) size-of-input-task
  // in such a way that
  // 1) task sizes add up to `size-of-input-task`, and
  // 2) task sizes from left to right are
  //    [open_batch_remaining_slot, max_batch_size, max_batch_size, ...,
  //     `size-of-input-task` - `sum-of-previous-elements`].
  // For example, an input task of size 10 with open_batch_remaining_slot 2
  // and max_batch_size 4 is split into tasks of sizes [2, 4, 4].
  //
  // REQUIRES:
  // The caller must make sure size-of-input-task is greater than
  // open_batch_remaining_slot.
  static Status SplitInputTask(
      std::unique_ptr<BatchTask>* input_task_ptr,
      int open_batch_remaining_slot, int max_batch_size,
      std::vector<std::unique_ptr<BatchTask>>* output_tasks);

  Status SplitOutputTensors(const std::vector<Tensor>& combined_outputs,
                            BatchT* batch) const;

  void ProcessFuncBatch(std::unique_ptr<BatchT> batch) const;

  // Processes a batch of one or more BatchTask entries.
  void ProcessBatch(std::unique_ptr<BatchT> batch) const;

  // Emits an index tensor, which the Unbatch op will use to un-concatenate
  // the tensor and attribute the pieces to the right batch keys. The index
  // tensor contains, for each input, a [batch_key, start_offset, end_offset]
  // triple, where start_offset and end_offset delimit the range of entries in
  // the concatenated tensors that belong to that input.
  //
  // Emits the result to the output at 'output_index' using 'context'.
  static Status EmitIndexTensor(OpKernelContext* context, const BatchT& batch,
                                int output_index);

  // Looks up the batcher queue for 'queue_name'. If it didn't previously
  // exist, creates it.
  Status LookupOrCreateBatcherQueue(const string& queue_name,
                                    BatcherQueueT** queue);

  // True if the user specified a batch processing function for this resource.
  const bool has_process_batch_function_;

  // A batch scheduler, and options for creating queues.
  std::shared_ptr<BatcherT> batcher_;
  BatcherT::QueueOptions batcher_queue_options_;

  // An adaptive batch scheduler, and options for creating queues.
  std::shared_ptr<AdaptiveBatcherT> adaptive_batcher_;
  AdaptiveBatcherT::QueueOptions adaptive_batcher_queue_options_;
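
  // Note (not enforced here): exactly one of 'batcher_' and
  // 'adaptive_batcher_' is populated, depending on which constructor was used
  // to create the resource.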

  // A collection of batcher queues, keyed on queue name.
  // TODO(olston): Garbage-collect unused queues (perhaps simply remove empty
  // ones (with a time delay?); it's okay if they get recreated later).
  mutable mutex batcher_queues_mu_;
  std::map<string, std::unique_ptr<BatcherQueueT>> batcher_queues_
      TF_GUARDED_BY(batcher_queues_mu_);

  std::vector<int32> allowed_batch_sizes_;
  // A concatenation of the entries of <allowed_batch_sizes_>, separated by
  // ",". This is used to record the batching parameters.
  string allowed_batch_sizes_str_;
};

}  // namespace serving
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_RESOURCE_BASE_H_