1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_RESOURCE_BASE_H_ 17 #define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_RESOURCE_BASE_H_ 18 19 #include <map> 20 21 #include "absl/strings/str_join.h" 22 #include "tensorflow/core/common_runtime/cost_measurement_registry.h" 23 #include "tensorflow/core/common_runtime/request_cost.h" 24 #include "tensorflow/core/framework/op_kernel.h" 25 #include "tensorflow/core/framework/resource_mgr.h" 26 #include "tensorflow/core/framework/tensor.h" 27 #include "tensorflow/core/framework/tensor_shape.h" 28 #include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h" 29 #include "tensorflow/core/kernels/batching_util/batch_scheduler.h" 30 #include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h" 31 #include "tensorflow/core/kernels/batching_util/threadsafe_status.h" 32 #include "tensorflow/core/platform/context.h" 33 #include "tensorflow/core/platform/status.h" 34 #include "tensorflow/core/platform/thread_annotations.h" 35 36 namespace tensorflow { 37 namespace serving { 38 39 // Base class for resource that encapsulating the state and logic for batching 40 // tensors. 41 class BatchResourceBase : public ResourceBase { 42 public: 43 // Given a BatchTask (from one op invocation) with 'num_outputs'== M and 44 // splitted into N sub tasks, TensorMatrix is a N X M matrix. 45 // Namely, TensorMatrix[i][j] indicates the i-th split tensor of j-th output; 46 // concatenating tensors along the 2nd dimension gives a output tensor. 47 typedef std::vector<std::vector<Tensor>> TensorMatrix; 48 49 // Ingests data from one invocation of the batch op. The data is enqueued to 50 // be combined with others into a batch, asynchronously. 51 Status RegisterInput(int64_t guid, OpKernelContext* context, 52 const string& batcher_queue_name, 53 AsyncOpKernel::DoneCallback done_callback); 54 55 public: 56 // One task to be batched, corresponds to a `slice` of input from one batch-op 57 // invocation. 58 // 59 // Given input from one batch-op invocation, a `slice` of this input is: 60 // 1) Split each Tensor in `BatchTask::inputs` along the 0th dimension. 61 // 2) 'split_index' is calculated along the 0-th dimension. 62 // 63 // Note input from one batch-op invocation is valid and considered a 64 // specialized `slice`. 65 struct BatchTask : public tensorflow::serving::BatchTask { 66 // A unique ID to identify this invocation of Batch. 67 int64 guid; 68 69 Context propagated_context; 70 71 std::vector<Tensor> inputs; 72 std::vector<Tensor> captured_inputs; 73 OpKernelContext* context; 74 AsyncOpKernel::DoneCallback done_callback; 75 76 // The index of this split, along the 0-th dimension of input from op 77 // invocation. 78 int split_index = 0; 79 80 // Two-dimensional tensor matrix, ownership shared by: 81 // 1) each split of task (to fill one row in this matrix) 82 // and 83 // 2) callback that runs to merge output of individual splits for an op 84 // invocation, after all splits complete. 85 std::shared_ptr<TensorMatrix> output; 86 87 // 'status' records error (could be from any split) if at least one split 88 // returns error, OK otherwise. 89 // Ownership is shared by individual splits and callback. 90 std::shared_ptr<ThreadSafeStatus> status; 91 92 bool is_partial = false; 93 94 uint64 start_time; 95 sizeBatchTask96 size_t size() const override { return inputs[0].shape().dim_size(0); } 97 98 // Create a split task from this one. The caller needs to setup the inputs 99 // of the new task 100 std::unique_ptr<BatchTask> CreateSplitTask( 101 int split_index, AsyncOpKernel::DoneCallback done_callback); 102 103 // RequestCost is for collecting the cost and must outlive the batching 104 // processing. 105 // 106 // For example, to collect cost in rpc processing, `request_cost` is owned 107 // by rpc handler and points to the RequestCost of an rpc which provides 108 // the inputs to this BatchTask. 109 // 110 // After the batch processing, the request cost will be incremented with 111 // this task's processing costs. 112 RequestCost* request_cost = nullptr; 113 114 protected: CreateDerivedTaskBatchTask115 virtual std::unique_ptr<BatchTask> CreateDerivedTask() { 116 return std::make_unique<BatchTask>(); 117 } 118 }; 119 120 // Appending a T suffix to make the type alias different to those in 121 // tensorflow::serving namespace, because some versions of compiler complain 122 // about changing meaning of the symbols. 123 using BatcherT = SharedBatchScheduler<BatchResourceBase::BatchTask>; 124 using AdaptiveBatcherT = 125 AdaptiveSharedBatchScheduler<BatchResourceBase::BatchTask>; 126 using BatcherQueueT = BatchScheduler<BatchResourceBase::BatchTask>; 127 using BatchT = Batch<BatchResourceBase::BatchTask>; 128 BatchResourceBase(bool has_process_batch_function,std::shared_ptr<BatcherT> batcher,const BatcherT::QueueOptions & batcher_queue_options,std::vector<int32> allowed_batch_sizes)129 BatchResourceBase(bool has_process_batch_function, 130 std::shared_ptr<BatcherT> batcher, 131 const BatcherT::QueueOptions& batcher_queue_options, 132 std::vector<int32> allowed_batch_sizes) 133 : has_process_batch_function_(has_process_batch_function), 134 batcher_(std::move(batcher)), 135 batcher_queue_options_(batcher_queue_options), 136 allowed_batch_sizes_(std::move(allowed_batch_sizes)) { 137 allowed_batch_sizes_str_ = absl::StrJoin(allowed_batch_sizes_, ","); 138 } 139 BatchResourceBase(bool has_process_batch_function,std::shared_ptr<AdaptiveBatcherT> batcher,const AdaptiveBatcherT::QueueOptions & batcher_queue_options,std::vector<int32> allowed_batch_sizes)140 BatchResourceBase(bool has_process_batch_function, 141 std::shared_ptr<AdaptiveBatcherT> batcher, 142 const AdaptiveBatcherT::QueueOptions& batcher_queue_options, 143 std::vector<int32> allowed_batch_sizes) 144 : has_process_batch_function_(has_process_batch_function), 145 adaptive_batcher_(std::move(batcher)), 146 adaptive_batcher_queue_options_(batcher_queue_options), 147 allowed_batch_sizes_(std::move(allowed_batch_sizes)) {} 148 149 static BatcherT::QueueOptions GetBatcherQueueOptions( 150 int32_t num_batch_threads, int32_t max_batch_size, 151 int32_t batch_timeout_micros, int32_t max_enqueued_batches, 152 const std::vector<int32>& allowed_batch_sizes, 153 bool enable_large_batch_splitting); 154 155 static AdaptiveBatcherT::QueueOptions GetAdaptiveBatcherQueueOptions( 156 int32_t max_batch_size, int32_t batch_timeout_micros, 157 int32_t max_enqueued_batches, bool enable_large_batch_splitting, 158 const std::vector<int32>& allowed_batch_sizes); 159 160 // Split 'input' of 'input_task_ptr' along 0th dimension, into a list of 161 // 'output_tasks'. 162 // Task sizes are determined by 163 // 1) open_batch_remaining_slot 164 // 2) max_batch_size 165 // 3) size-of-input-task 166 // in a way that 167 // 1) Task sizes add up to `size-of-input-task`. 168 // 2) Task sizes from left to right are like 169 // [open_batch_remaining_slot, max_batch_size, max_batch_size, ..., 170 // `size-of-input-task` - `sum-of-previous-elements`]. 171 // 172 // REQUIRES: 173 // Caller should make sure size-of-input-task is greater than 174 // open_batch_remaining_slot. 175 static Status SplitInputTask( 176 std::unique_ptr<BatchTask>* input_task_ptr, int open_batch_remaining_slot, 177 int max_batch_size, 178 std::vector<std::unique_ptr<BatchTask>>* output_tasks); 179 180 private: 181 // Implementation of calling the process batch function. 182 virtual void ProcessFuncBatchImpl( 183 const BatchResourceBase::BatchTask& last_task, 184 absl::Span<const Tensor> inputs, std::vector<Tensor>* combined_outputs, 185 std::function<void(const Status&)> done) const = 0; 186 187 // Factory method for creating a BatchTask, overridable by subclasses. 188 virtual Status CreateBatchTask( 189 OpKernelContext* context, 190 std::unique_ptr<BatchResourceBase::BatchTask>* output) const; 191 192 // Validates that it's legal to combine the tasks in 'batch' into a batch. 193 // Assumes the batch is non-empty. 194 static Status ValidateBatch(const BatchT& batch); 195 196 // Returns the smallest entry in 'allowed_batch_sizes_' that is greater than 197 // or equal to 'batch_size'. If 'allowed_batch_sizes_' is empty, simply 198 // returns 'batch_size'. 199 int RoundToLowestAllowedBatchSize(int batch_size) const; 200 201 Status ConcatInputTensors(const BatchT& batch, OpKernelContext* context, 202 std::vector<Tensor>* concatenated_tensors) const; 203 204 Status SplitOutputTensors(const std::vector<Tensor>& combined_outputs, 205 BatchT* batch) const; 206 207 void ProcessFuncBatch(std::unique_ptr<BatchT> batch) const; 208 209 // Processes a batch of one or more BatchTask entries. 210 void ProcessBatch(std::unique_ptr<BatchT> batch) const; 211 212 // Emits an index tensor, which the Unbatch op will use to un-concatenate 213 // the tensor and attribute the pieces to the right batch keys. The index 214 // tensor contains, for each input: [batch_key, start_offset, end_offset] 215 // where start_offset and end_offset represent the range of entries in the 216 // concatenated tensors that belong to that input. 217 // 218 // Emits the result to the output at 'output_index' using 'context'. 219 static Status EmitIndexTensor(OpKernelContext* context, const BatchT& batch, 220 int output_index); 221 222 // Looks up the batcher queue for 'queue_name'. If it did't previously exist, 223 // creates it. 224 Status LookupOrCreateBatcherQueue(const string& queue_name, 225 BatcherQueueT** queue); 226 227 // Splits the batch cost to each task. 228 // 229 // Inputs: 230 // 1) batch_cost_measurement, which provides the total cost and cost type; 231 // 2) processed_size, it's the batch size plus the padding amount; 232 // 3) batch, provides the batch size. 233 // 234 // Outputs: 235 // The request_cost in each batch task will be updated. This function will use 236 // two approaches to split the batch cost (if it's non-zero), thus two costs 237 // will be output. One approach splits the cost proportionally to the each 238 // task's size, but paddings do not share any cost, while the other splits the 239 // cost proportionally to each task or padding's size. 240 void SplitBatchCost(CostMeasurement* batch_cost_measurement, 241 const int64_t processed_size, BatchT& batch) const; 242 243 // True if user specified a batch processing function for this resource. 244 const bool has_process_batch_function_; 245 // A batch scheduler, and options for creating queues. 246 std::shared_ptr<BatcherT> batcher_; 247 BatcherT::QueueOptions batcher_queue_options_; 248 249 // A batch scheduler, and options for creating queues. 250 std::shared_ptr<AdaptiveBatcherT> adaptive_batcher_; 251 AdaptiveBatcherT::QueueOptions adaptive_batcher_queue_options_; 252 253 // A collection of batcher queues, keyed on queue name. 254 // TODO(olston): Garbage-collect unused queues (perhaps simply remove empty 255 // ones (with a time delay?); it's okay if they get recreated later). 256 mutable mutex batcher_queues_mu_; 257 std::map<string, std::unique_ptr<BatcherQueueT>> batcher_queues_ 258 TF_GUARDED_BY(batcher_queues_mu_); 259 260 std::vector<int32> allowed_batch_sizes_; 261 // A concatenated string of <allowed_batch_sizes_>, separated by ",". This is 262 // used to record batching parameter. 263 string allowed_batch_sizes_str_; 264 }; 265 266 } // namespace serving 267 } // namespace tensorflow 268 269 #endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_RESOURCE_BASE_H_ 270