/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#import "tensorflow/lite/delegates/gpu/metal_delegate.h"

#import <Metal/Metal.h>

#include <algorithm>
#include <atomic>
#include <chrono>
#include <cstring>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/delegates/gpu/common/convert.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/quantization_util.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/metal/buffer_convert.h"
#include "tensorflow/lite/delegates/gpu/metal/common.h"
#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
#include "tensorflow/lite/delegates/gpu/metal/inference_context.h"
#include "tensorflow/lite/delegates/gpu/common/precision.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/minimal_logging.h"

namespace tflite {
namespace gpu {
namespace metal {
namespace {

// Multi-thread safe alarm clock for preventing GPU sleeping. It spawns lightweight compute tasks
// while no inference is being performed on the device, which reduces CPU-to-CPU inference
// latency. The class is used only for the kAggressive wait type.
class GpuAlarmClock {
 public:
  explicit GpuAlarmClock(id<MTLCommandQueue> command_queue) {
    auto device = [command_queue device];
    std::lock_guard<std::mutex> lock(alarms_mutex_);
    if (!alarms_) alarms_ = new std::map<id<MTLDevice>, GpuAlarmClockInternal*>();
    auto it = alarms_->find(device);
    if (it == alarms_->end()) {
      internal_ = new GpuAlarmClockInternal(command_queue);
      (*alarms_)[device] = internal_;
    } else {
      internal_ = it->second;
      internal_->total_alarms_++;
    }
  }
  ~GpuAlarmClock() {
    std::lock_guard<std::mutex> lock(alarms_mutex_);
    if (--internal_->total_alarms_ > 0) return;
    Stop();
    delete internal_;
    // Remove the alarm from the container to free up the device handle.
    for (auto it = alarms_->begin(); it != alarms_->end(); ++it) {
      if (it->second == internal_) {
        alarms_->erase(it);
        break;
      }
    }
    if (alarms_->empty()) {
      delete alarms_;
      alarms_ = nullptr;
    }
  }
  void Start() {
    if (started_) return;
    started_ = true;
    internal_->active_alarms_++;
  }
  void Stop() {
    if (!started_) return;
    started_ = false;
    internal_->active_alarms_--;
  }

 private:
  class GpuAlarmClockInternal {
   public:
    id<MTLComputePipelineState> stub_program_;
    id<MTLBuffer> stub_buffer_;
    explicit GpuAlarmClockInternal(id<MTLCommandQueue> command_queue) {
      command_queue_ = command_queue;
      device_ = [command_queue_ device];
      total_alarms_ = 1;
      NSString* error;
      id<MTLComputePipelineState> program;
      // TODO(impjdi): Properly handle returned status.
      CreateComputeProgram(device_,
                           @"kernel void ComputeFunction(device int* output_buffer [[buffer(0)]]) "
                           @"{ output_buffer[0] = 0; }",
                           @"ComputeFunction", nullptr, &program)
          .IgnoreError();
      stub_program_ = program;
      stub_buffer_ = [device_ newBufferWithLength:sizeof(int) * 4
                                          options:MTLResourceHazardTrackingModeUntracked];
      alarm_thread_ = std::thread([this]() {
        id<MTLCommandBuffer> prev_command_buffer;
        while (!release_thread_) {
          @autoreleasepool {
            if (active_alarms_ == total_alarms_) {
              id<MTLCommandBuffer> command_buffer = [command_queue_ commandBuffer];
              id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
              [encoder setComputePipelineState:stub_program_];
              [encoder setBuffer:stub_buffer_ offset:0 atIndex:0];
              [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1)
                      threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
              [encoder endEncoding];
              [command_buffer commit];
              if (prev_command_buffer != nil) [prev_command_buffer waitUntilScheduled];
              prev_command_buffer = command_buffer;
            } else {
              std::this_thread::sleep_for(std::chrono::milliseconds(1));
            }
          }
        }
      });
    }
    ~GpuAlarmClockInternal() {
      release_thread_ = true;
      alarm_thread_.join();
    }

   private:
    friend class GpuAlarmClock;
    std::atomic<int> active_alarms_;
    std::thread alarm_thread_;
    id<MTLCommandQueue> command_queue_;
    id<MTLDevice> device_;
    volatile bool release_thread_ = false;
    int total_alarms_ = 0;
  };
  static std::map<id<MTLDevice>, GpuAlarmClockInternal*>* alarms_;
  std::mutex alarms_mutex_;
  GpuAlarmClockInternal* internal_;
  bool started_ = false;
};
std::map<id<MTLDevice>, GpuAlarmClock::GpuAlarmClockInternal*>* GpuAlarmClock::alarms_ = nullptr;

// Forward declaration.
TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate);

class Delegate {
  struct ValueRef {
    BHWC shape;
    int64_t tensor_id;
  };

 public:
  explicit Delegate(const TFLGpuDelegateOptions* options) {
    if (options) {
      options_ = *options;
    } else {
      options_ = TFLGpuDelegateOptionsDefault();
    }
    metal_device_ = MTLCreateSystemDefaultDevice();
    command_queue_ = [metal_device_ newCommandQueue];
    if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeAggressive) {
      gpu_alarm_clock_ = std::unique_ptr<GpuAlarmClock>(new GpuAlarmClock(command_queue_));
      NSString* code = @R"(
          kernel void ComputeFunction(device int* output_buffer [[buffer(0)]],
                                      constant int& value [[buffer(1)]]) {
            output_buffer[0] = value;
          }
      )";
      NSString* error;
      id<MTLComputePipelineState> signal_program;
      // TODO(impjdi): Properly handle returned status.
      CreateComputeProgram(metal_device_, code, @"ComputeFunction", nullptr, &signal_program)
          .IgnoreError();
      signal_program_ = signal_program;
      signal_buffer_ = [metal_device_ newBufferWithLength:sizeof(int) * 4
                                                  options:MTLResourceStorageModeShared |
                                                          MTLResourceHazardTrackingModeUntracked];
    }
  }

  absl::Status BindBufferToTensor(id<MTLBuffer> buffer, int tensor_index) {
    // The tensor index is expected to be an input or output tensor of the interpreter.
    // For a quantized model, the buffer should be linked with its dequantized counterpart.
    if (quant_conversion_map_.find(tensor_index) != quant_conversion_map_.end()) {
      tensor_index = quant_conversion_map_[tensor_index];
      // Remove the [dequantized tensor ID] -> [quantized tensor ID] mapping to prevent extra
      // dequant/quant on inputs/outputs.
      quant_conversion_map_.erase(tensor_index);
    }
    for (auto& input : graph_inputs_) {
      if (input.tensor_id == tensor_index) {
        input_output_buffers_[input.id] = buffer;
        if (bphwc4_buffers_[input.id] != buffer) {
          bphwc_buffers_updated_ = true;
        }
        bphwc4_buffers_[input.id] = buffer;
        input.set_externally = true;
        return absl::OkStatus();
      }
    }
    for (auto& output : graph_outputs_) {
      if (output.tensor_id == tensor_index) {
        input_output_buffers_[output.id] = buffer;
        if (bphwc4_buffers_[output.id] != buffer) {
          bphwc_buffers_updated_ = true;
        }
        bphwc4_buffers_[output.id] = buffer;
        output.set_externally = true;
        return absl::OkStatus();
      }
    }
    return absl::NotFoundError("Couldn't find tensor: " + std::to_string(tensor_index));
  }

  void SetCommandBuffer(id<MTLCommandBuffer> command_buffer) {
    external_command_buffer_ = command_buffer;
  }

  // This directs the runtime to allocate memory for input/output temporary
  // tensors that require dequantization/quantization.
  absl::Status GetRequiredTemporaries(TfLiteContext* context, TfLiteNode* node,
                                      TfLiteIntArray** temporaries_array_ptr) {
    if (quant_conversion_map_.empty()) return absl::OkStatus();

    std::vector<int> temporary_tensor_ids;
    for (auto index : input_tensor_ids_) {
      if (quant_conversion_map_.find(index) != quant_conversion_map_.end()) {
        temporary_tensor_ids.push_back(index);
      }
    }
    for (auto index : output_tensor_ids_) {
      if (quant_conversion_map_.find(index) != quant_conversion_map_.end()) {
        temporary_tensor_ids.push_back(index);
      }
    }
    *temporaries_array_ptr = TfLiteIntArrayCreate(temporary_tensor_ids.size());
    for (int i = 0; i < temporary_tensor_ids.size(); ++i) {
      (*temporaries_array_ptr)->data[i] = temporary_tensor_ids[i];
    }
    return absl::OkStatus();
  }

  absl::Status Prepare(TfLiteContext* context, const TfLiteDelegateParams* delegate_params) {
    // Extract the TFLite delegate execution plan from the context and convert it into GraphFloat32.
    GraphFloat32 graph;
    quant_conversion_map_.clear();
    if (options_.enable_quantization) {
      RETURN_IF_ERROR(BuildFinalModel(context, delegate_params, &graph, &quant_conversion_map_));
    } else {
      RETURN_IF_ERROR(BuildFinalModel(context, delegate_params, &graph));
    }

    // TODO(impjdi): Remove code duplication.
    auto values = graph.values();
    auto find_value = [&](int tensor_index) -> Value* {
      for (auto value : values) {
        if (value->tensor.ref == tensor_index) return value;
      }
      return nullptr;
    };
    tensors_.reserve(values.back()->id + 1);
    for (const auto* value : values) {
      if (tensors_.size() <= value->id) tensors_.resize(value->id + 1);
      tensors_[value->id] = {
          value->tensor.shape,  // .shape
          value->tensor.ref,    // .tensor_id
      };
    }

    // Prepare graph inputs.
    //
    // Note that graph.inputs() cannot be used directly, as the notion of graph input has a
    // different meaning in the public API and the GPU-internal API.
    for (int tensor_index : TfLiteIntArrayView(delegate_params->input_tensors)) {
      auto* tensor = &context->tensors[tensor_index];
      if (IsConstantTensor(tensor)) continue;
      // For quantized models, the actual inputs of the GPU graph are float tensors, so the 8-bit
      // inputs to the delegate kernel need to be dequantized before feeding the GPU graph.
      if (options_.enable_quantization &&
          quant_conversion_map_.find(tensor_index) != quant_conversion_map_.end()) {
        tensor_index = quant_conversion_map_[tensor_index];
        tensor = &context->tensors[tensor_index];
      }
      const auto* input = find_value(tensor_index);
      if (!input || tensor->type != TfLiteType::kTfLiteFloat32) {
        return absl::NotFoundError("Input tensor is not found in the graph.");
      }

      inputs_.push_back(input->id);
      input_tensor_ids_.push_back(tensor_index);
      tensor->buffer_handle = input->id;
      tensor->delegate = &delegate_;
    }

    // Prepare graph outputs.
    //
    // Note that graph.outputs() cannot be used directly, as the notion of graph output has a
    // different meaning in the public API and the GPU-internal API.
    for (int tensor_index : TfLiteIntArrayView(delegate_params->output_tensors)) {
      auto* tensor = &context->tensors[tensor_index];
      if (IsConstantTensor(tensor)) continue;
      // For quantized models, the actual outputs of the GPU graph are float tensors, so they
      // should be quantized to become the 8-bit outputs of the delegate.
      if (options_.enable_quantization &&
          quant_conversion_map_.find(tensor_index) != quant_conversion_map_.end()) {
        tensor_index = quant_conversion_map_[tensor_index];
        tensor = &context->tensors[tensor_index];
      }
      const auto* output = find_value(tensor_index);
      if (!output || tensor->type != TfLiteType::kTfLiteFloat32) {
        return absl::NotFoundError("Output tensor is not found in the graph.");
      }

      outputs_.push_back(output->id);
      output_tensor_ids_.push_back(tensor_index);
      tensor->buffer_handle = output->id;
      tensor->delegate = &delegate_;
    }

    std::string device_name = std::string([[metal_device_ name] UTF8String]);
    GpuInfo gpu_info;
    GetGpuInfoFromDeviceDescription(device_name, GpuApi::kMetal, &gpu_info);
    size_t storage_type_size;
    CalculationsPrecision precision;
    if (options_.allow_precision_loss) {
      storage_type_size = sizeof(HalfBits);
      if (gpu_info.IsRoundToNearestSupported()) {
        precision = CalculationsPrecision::F16;
      } else {
        precision = CalculationsPrecision::F32_F16;
      }
    } else {
      storage_type_size = sizeof(float);
      precision = CalculationsPrecision::F32;
    }

    // TODO(impjdi): Merge logic with above.
    // Pre-allocate input and output Metal buffers.
    std::vector<::tflite::gpu::ValueId> input_ids;
    input_ids.reserve(inputs_.size());
    std::map<::tflite::gpu::ValueId, BHWC> input_dimensions;
    graph_inputs_.reserve(inputs_.size());
    for (const ValueId input : inputs_) {
      const auto& input_tensor = tensors_[input];
      const auto tensor_id = input_tensor.tensor_id;
      input_ids.push_back(input);
      if (input_tensor.shape.b != 1) {
        return absl::UnimplementedError("Batching is not supported yet.");
      }
      input_dimensions[input] = input_tensor.shape;
      graph_inputs_.push_back({
          input,               // .id
          tensor_id,           // .tensor_id
          input_tensor.shape,  // .shape
          false,               // .set_externally
      });
      int bhwc_length = static_cast<int>(sizeof(float) * input_tensor.shape.DimensionsProduct());
      int bphwc4_length =
          static_cast<int>(storage_type_size * GetElementsSizeForPHWC4(input_tensor.shape));
      id<MTLBuffer> buffer = [metal_device_ newBufferWithLength:bhwc_length
                                                        options:MTLResourceStorageModeShared];
      input_output_buffers_[input] = buffer;
      if (options_.allow_precision_loss || input_tensor.shape.c != 4) {
        bphwc4_buffers_[input] = [metal_device_ newBufferWithLength:bphwc4_length
                                                            options:MTLResourceStorageModeShared];
        if (converter_to_BPHWC4_ == nil) {
          converter_to_BPHWC4_ =
              [[TFLBufferConvert alloc] initWithDevice:metal_device_
                                             isFloat16:options_.allow_precision_loss
                                       convertToPBHWC4:true];
          if (converter_to_BPHWC4_ == nil) {
            return absl::InternalError("Error initializing input buffer converter.");
          }
        }
      } else {
        bphwc4_buffers_[input] = buffer;
      }
    }

    std::vector<::tflite::gpu::ValueId> output_ids;
    output_ids.reserve(outputs_.size());
    graph_outputs_.reserve(outputs_.size());
    for (const ValueId output : outputs_) {
      const auto& output_tensor = tensors_[output];
      const auto tensor_id = output_tensor.tensor_id;
      output_ids.push_back(output);
      graph_outputs_.push_back({
          output,               // .id
          tensor_id,            // .tensor_id
          output_tensor.shape,  // .shape
          false,                // .set_externally
      });
      // Create the BHWC buffer.
      int bhwc_length = static_cast<int>(sizeof(float) * output_tensor.shape.DimensionsProduct());
      int bphwc4_length =
          static_cast<int>(storage_type_size * GetElementsSizeForPHWC4(output_tensor.shape));
      id<MTLBuffer> buffer = [metal_device_ newBufferWithLength:bhwc_length
                                                        options:MTLResourceStorageModeShared];
      input_output_buffers_[output] = buffer;
      if (options_.allow_precision_loss || output_tensor.shape.c != 4) {
        bphwc4_buffers_[output] = [metal_device_ newBufferWithLength:bphwc4_length
                                                             options:MTLResourceStorageModeShared];
        if (converter_from_BPHWC4_ == nil) {
          converter_from_BPHWC4_ =
              [[TFLBufferConvert alloc] initWithDevice:metal_device_
                                             isFloat16:options_.allow_precision_loss
                                       convertToPBHWC4:false];
          if (converter_from_BPHWC4_ == nil) {
            return absl::InternalError("Error initializing output buffer converter.");
          }
        }
      } else {
        bphwc4_buffers_[output] = buffer;
      }
    }
    bphwc_buffers_updated_ = true;

    InferenceContext::CreateInferenceInfo create_info;
    create_info.precision = precision;
    create_info.storage_type = TensorStorageType::BUFFER;
    RETURN_IF_ERROR(
        inference_context_.InitFromGraphWithTransforms(create_info, &graph, metal_device_));
    return absl::OkStatus();
  }

  absl::Status Invoke(TfLiteContext* context) {
    if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeAggressive)
      gpu_alarm_clock_->Stop();
    // We only need synchronization, so volatile works better than atomic, which reads from global
    // memory each time.
    __block volatile bool buffer_completed = false;
    id<MTLCommandBuffer> command_buffer = external_command_buffer_;
    if (external_command_buffer_ == nil) {
      command_buffer = [command_queue_ commandBuffer];
    }
    const bool flush = external_command_buffer_ == nil &&
        (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive ||
         options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeAggressive);
    const int flush_period = 8;

    const bool is_quantized_model = !quant_conversion_map_.empty();
    if (is_quantized_model) {
      RETURN_IF_ERROR(DequantizeInputs(context, input_tensor_ids_, quant_conversion_map_));
    }

    // Convert CPU HWC input data to PHWC4 and fill the GPU buffers.
    for (const auto& input : graph_inputs_) {
      if (input.set_externally) continue;
      // The user provides data in CPU memory for this buffer - it needs to be copied to the
      // MTLBuffer.
      TfLiteTensor* tensor = &context->tensors[input.tensor_id];
      void* gpu_ptr = [input_output_buffers_[input.id] contents];
      std::memcpy(gpu_ptr, tensor->data.f, input.shape.DimensionsProduct() * sizeof(float));
      if (input_output_buffers_[input.id] == bphwc4_buffers_[input.id]) continue;
      id<MTLComputeCommandEncoder> input_encoder = [command_buffer computeCommandEncoder];
      [converter_to_BPHWC4_ convertWithEncoder:input_encoder
                                         shape:input.shape
                                  sourceBuffer:input_output_buffers_[input.id]
                               convertedBuffer:bphwc4_buffers_[input.id]];
      [input_encoder endEncoding];
    }

    if (bphwc_buffers_updated_) {
      inference_context_.UpdatePreallocatedTensors(bphwc4_buffers_);
      bphwc_buffers_updated_ = false;
    }

    @autoreleasepool {
      if (flush) {
        [command_buffer commit];
        inference_context_.EncodeWithCommandQueue(command_queue_, flush_period);
        command_buffer = [command_queue_ commandBuffer];
      } else {
        inference_context_.EncodeWithCommandBuffer(command_buffer);
      }
    }

    for (const auto& output : graph_outputs_) {
      if (output.set_externally) continue;
      if (bphwc4_buffers_[output.id] == input_output_buffers_[output.id]) continue;
      id<MTLComputeCommandEncoder> output_encoder = [command_buffer computeCommandEncoder];
      [converter_from_BPHWC4_ convertWithEncoder:output_encoder
                                           shape:output.shape
                                    sourceBuffer:bphwc4_buffers_[output.id]
                                 convertedBuffer:input_output_buffers_[output.id]];
      [output_encoder endEncoding];
    }

    if (external_command_buffer_ == nil) {
      if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive) {
        [command_buffer addCompletedHandler:^(id<MTLCommandBuffer>) {
          buffer_completed = true;
        }];
      }
      [command_buffer commit];
      if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive) {
        while (!buffer_completed) {
          // Busy wait. Spin on a local variable; a volatile flag would hit RAM on every access.
          for (volatile int i = 0; i < 100; i++) {
          }
        }
      } else if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypePassive) {
        // Passive wait: this thread sleeps until the GPU finishes.
        [command_buffer waitUntilCompleted];
      } else if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeAggressive) {
        id<MTLCommandBuffer> signal_cb = [command_queue_ commandBuffer];
        id<MTLComputeCommandEncoder> signal_encoder = [signal_cb computeCommandEncoder];
        [signal_encoder setComputePipelineState:signal_program_];
        [signal_encoder setBuffer:signal_buffer_ offset:0 atIndex:0];
        signal_value_++;
        [signal_encoder setBytes:&signal_value_ length:sizeof(int) atIndex:1];
        [signal_encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1)
                       threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
        [signal_encoder endEncoding];
        [signal_cb commit];
        gpu_alarm_clock_->Start();
        const int* signal_ptr = reinterpret_cast<const int*>([signal_buffer_ contents]);
        while (signal_ptr[0] != signal_value_) {
          // Busy wait. Spin on a local variable to avoid RAM pressure.
          for (volatile int i = 0; i < 100; i++) {
          }
        }
      }
    } else {
      // The external command buffer must be set before every Invoke call.
      external_command_buffer_ = nil;
      // An external command buffer is assigned, so all output buffers are controlled by the user.
      for (const auto& output : graph_outputs_) {
        if (!output.set_externally) {
          return absl::InternalError(
              "External command encoder is used, but not all output buffers are bound.");
        }
      }
      return absl::OkStatus();
    }

    // Retrieve data from the GPU and convert from PHWC4 to HWC.
    for (const auto& output : graph_outputs_) {
      if (output.set_externally) continue;
      // The user retrieves data into CPU memory for this buffer - it needs to be copied from the
      // MTLBuffer.
      TfLiteTensor* tensor = context->tensors + output.tensor_id;
      const void* gpu_ptr = [input_output_buffers_[output.id] contents];
      std::memcpy(tensor->data.f, gpu_ptr, output.shape.DimensionsProduct() * sizeof(float));
    }
    if (is_quantized_model) {
      RETURN_IF_ERROR(QuantizeOutputs(context, output_tensor_ids_, quant_conversion_map_));
    }
    return absl::OkStatus();
  }

  const TFLGpuDelegateOptions options() const { return options_; }

  TfLiteDelegate* tflite_delegate() { return &delegate_; }

 private:
  TfLiteDelegate delegate_ = {
      reinterpret_cast<void*>(this),  // .data_
      DelegatePrepare,                // .Prepare
      nullptr,                        // .CopyFromBufferHandle
      nullptr,                        // .CopyToBufferHandle
      nullptr,                        // .FreeBufferHandle
      kTfLiteDelegateFlagsNone,       // .flags
  };

  TFLGpuDelegateOptions options_;

  id<MTLDevice> metal_device_;

  std::vector<ValueRef> tensors_;  // indexed by ValueId
  std::vector<ValueId> inputs_;
  std::vector<ValueId> outputs_;
  std::vector<int64_t> input_tensor_ids_;
  std::vector<int64_t> output_tensor_ids_;
  // Whenever quantized inference is enabled, this maps the tensor index of each
  // originally quantized (8-bit) tensor to its float version added in
  // model_builder - and vice versa.
  absl::flat_hash_map<int, int> quant_conversion_map_;

  InferenceContext inference_context_;
  // Input and output buffers that are passed into the Metal inference engine.
  std::map<::tflite::gpu::ValueId, id<MTLBuffer>> input_output_buffers_;
  std::map<::tflite::gpu::ValueId, id<MTLBuffer>> bphwc4_buffers_;
  bool bphwc_buffers_updated_ = true;
  TFLBufferConvert* converter_to_BPHWC4_ = nil;
  TFLBufferConvert* converter_from_BPHWC4_ = nil;

  struct BufferDescriptor {
    ValueId id;
    int64_t tensor_id;
    BHWC shape;
    bool set_externally;  // the user fills/retrieves data on this MTLBuffer
  };
  std::vector<BufferDescriptor> graph_inputs_;
  std::vector<BufferDescriptor> graph_outputs_;

  id<MTLCommandBuffer> external_command_buffer_ = nil;
  id<MTLCommandQueue> command_queue_;
  std::unique_ptr<GpuAlarmClock> gpu_alarm_clock_;
  id<MTLComputePipelineState> signal_program_;
  id<MTLBuffer> signal_buffer_;
  int signal_value_ = 0;
};

Delegate* GetMetalDelegate(TfLiteNode* node) {
  return reinterpret_cast<Delegate*>(node->user_data);
}

Delegate* GetMetalDelegate(TfLiteDelegate* delegate) {
  return reinterpret_cast<Delegate*>(delegate->data_);
}

TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
  const TfLiteRegistration kRegistration = {
      // .init
      [](TfLiteContext* context, const char* buffer, size_t) -> void* {
        const auto* params = reinterpret_cast<const TfLiteDelegateParams*>(buffer);
        auto* metal_delegate = GetMetalDelegate(params->delegate);
        // Everything below should happen in the prepare function call, but TFLite for whatever
        // reason forbids that.
        const auto status = metal_delegate->Prepare(context, params);
        if (status.ok()) return metal_delegate;
        TF_LITE_KERNEL_LOG(context, "TfLiteMetalDelegate Prepare: %s",
                           std::string(status.message()).c_str());
        return nullptr;
      },
      // .free
      [](TfLiteContext*, void* buffer) -> void {},
      // .prepare
      [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
        if (!node->user_data) {
          return kTfLiteError;
        }

        auto* gpu_delegate_kernel = GetMetalDelegate(node);
        const auto status =
            gpu_delegate_kernel->GetRequiredTemporaries(context, node, &node->temporaries);
        if (!status.ok()) {
          TF_LITE_KERNEL_LOG(context, "TfLiteMetalDelegate Prepare: %s",
                             std::string(status.message()).c_str());
          return kTfLiteError;
        }
        return node->user_data ? kTfLiteOk : kTfLiteError;
      },
      // .invoke
      [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
        const auto status = GetMetalDelegate(node)->Invoke(context);
        if (status.ok()) return kTfLiteOk;
        TF_LITE_KERNEL_LOG(context, "TfLiteMetalDelegate Invoke: %s",
                           std::string(status.message()).c_str());
        return kTfLiteError;
      },
      nullptr,                // .profiling_string
      0,                      // .builtin_code
      "TfLiteMetalDelegate",  // .custom_name
      1,                      // .version
  };
  TfLiteIntArray* ops_to_replace =
      GetOpsToReplace(context, GetMetalDelegate(delegate)->options().enable_quantization);
  const auto status = context->ReplaceNodeSubsetsWithDelegateKernels(context, kRegistration,
                                                                     ops_to_replace, delegate);
  TfLiteIntArrayFree(ops_to_replace);
  return status;
}

}  // namespace
}  // namespace metal
}  // namespace gpu
}  // namespace tflite

TfLiteDelegate* TFLGpuDelegateCreate(const TFLGpuDelegateOptions* options) {
  TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO, "Created TensorFlow Lite delegate for Metal.");
  auto* metal_delegate = new ::tflite::gpu::metal::Delegate(options);
  return metal_delegate ? metal_delegate->tflite_delegate() : nullptr;
}

void TFLGpuDelegateDelete(TfLiteDelegate* delegate) {
  delete ::tflite::gpu::metal::GetMetalDelegate(delegate);
}

bool TFLGpuDelegateBindMetalBufferToTensor(TfLiteDelegate* delegate, int tensor_index,
                                           id<MTLBuffer> buffer) {
  auto* metal_delegate = ::tflite::gpu::metal::GetMetalDelegate(delegate);
  return metal_delegate && metal_delegate->BindBufferToTensor(buffer, tensor_index).ok();
}

// Note: This function is not exposed in `metal_delegate.h`, but it is exposed in
// `metal_delegate_internal.h`.
bool TFLGpuDelegateSetCommandBuffer(TfLiteDelegate* delegate,
                                    id<MTLCommandBuffer> command_buffer) {
  auto* metal_delegate = ::tflite::gpu::metal::GetMetalDelegate(delegate);
  if (!metal_delegate) return false;
  metal_delegate->SetCommandBuffer(command_buffer);
  return true;
}

TFLGpuDelegateOptions TFLGpuDelegateOptionsDefault() {
  TFLGpuDelegateOptions options = {
      .allow_precision_loss = false,
      .wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypePassive,
      .enable_quantization = true,
  };
  return options;
}
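// -----------------------------------------------------------------------------
// Usage sketch (illustrative only, kept in a comment so it is not compiled as
// part of this translation unit): one way an application might attach this
// delegate to a TFLite interpreter through the C API defined above. The model
// path, tensor index, and MTLBuffer are placeholders, and error handling is
// elided; consult the TFLite GPU delegate documentation for the authoritative
// integration steps.
//
//   // Build the interpreter as usual.
//   auto model = tflite::FlatBufferModel::BuildFromFile("model.tflite");
//   tflite::ops::builtin::BuiltinOpResolver resolver;
//   std::unique_ptr<tflite::Interpreter> interpreter;
//   tflite::InterpreterBuilder(*model, resolver)(&interpreter);
//
//   // Create the Metal delegate and hand the graph over to it.
//   TFLGpuDelegateOptions options = TFLGpuDelegateOptionsDefault();
//   TfLiteDelegate* delegate = TFLGpuDelegateCreate(&options);
//   if (interpreter->ModifyGraphWithDelegate(delegate) == kTfLiteOk) {
//     // Optionally bind a user-owned MTLBuffer to an input/output tensor for
//     // zero-copy I/O (hypothetical tensor_index and metal_buffer):
//     //   TFLGpuDelegateBindMetalBufferToTensor(delegate, tensor_index, metal_buffer);
//     interpreter->Invoke();
//   }
//
//   // Clean up once the interpreter no longer uses the delegate.
//   TFLGpuDelegateDelete(delegate);
// -----------------------------------------------------------------------------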