/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#import "tensorflow/lite/delegates/gpu/metal_delegate.h"

#import <Metal/Metal.h>

#include <algorithm>
#include <atomic>
#include <chrono>
#include <cstring>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/delegates/gpu/common/convert.h"
#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/precision.h"
#include "tensorflow/lite/delegates/gpu/common/quantization_util.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/metal/buffer_convert.h"
#include "tensorflow/lite/delegates/gpu/metal/common.h"
#include "tensorflow/lite/delegates/gpu/metal/inference_context.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/minimal_logging.h"

namespace tflite {
namespace gpu {
namespace metal {
namespace {

// Multi-thread safe alarm clock for preventing GPU sleeping. It spawns lightweight compute tasks
// until no inference is being performed on the device. It reduces the CPU-to-CPU inference
// latency. The class is used only for the kAggressive wait type.
class GpuAlarmClock {
 public:
  explicit GpuAlarmClock(id<MTLCommandQueue> command_queue) {
    auto device = [command_queue device];
    std::lock_guard<std::mutex> lock(alarms_mutex_);
    if (!alarms_) alarms_ = new std::map<id<MTLDevice>, GpuAlarmClockInternal*>();
    auto it = alarms_->find(device);
    if (it == alarms_->end()) {
      internal_ = new GpuAlarmClockInternal(command_queue);
      (*alarms_)[device] = internal_;
    } else {
      internal_ = it->second;
      internal_->total_alarms_++;
    }
  }
  ~GpuAlarmClock() {
    std::lock_guard<std::mutex> lock(alarms_mutex_);
    if (--internal_->total_alarms_ > 0) return;
    Stop();
    delete internal_;
    // Remove the alarm from the container to free up the device handle.
    for (auto it = alarms_->begin(); it != alarms_->end(); ++it) {
      if (it->second == internal_) {
        alarms_->erase(it);
        break;
      }
    }
    if (alarms_->empty()) {
      delete alarms_;
      alarms_ = nullptr;
    }
  }
  void Start() {
    if (started_) return;
    started_ = true;
    internal_->active_alarms_++;
  }
  void Stop() {
    if (!started_) return;
    started_ = false;
    internal_->active_alarms_--;
  }

 private:
  class GpuAlarmClockInternal {
   public:
    id<MTLComputePipelineState> stub_program_;
    id<MTLBuffer> stub_buffer_;
    explicit GpuAlarmClockInternal(id<MTLCommandQueue> command_queue) {
      command_queue_ = command_queue;
      device_ = [command_queue_ device];
      total_alarms_ = 1;
      id<MTLComputePipelineState> program;
      // TODO(impjdi): Properly handle returned status.
      CreateComputeProgram(device_,
                           @"kernel void ComputeFunction(device int* output_buffer [[buffer(0)]]) "
                           @"{ output_buffer[0] = 0; }",
                           @"ComputeFunction", nullptr, &program)
          .IgnoreError();
      stub_program_ = program;
      stub_buffer_ = [device_ newBufferWithLength:sizeof(int) * 4
                                          options:MTLResourceHazardTrackingModeUntracked];
      alarm_thread_ = std::thread([this]() {
        id<MTLCommandBuffer> prev_command_buffer;
        while (!release_thread_) {
          @autoreleasepool {
            if (active_alarms_ == total_alarms_) {
              id<MTLCommandBuffer> command_buffer = [command_queue_ commandBuffer];
              id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
              [encoder setComputePipelineState:stub_program_];
              [encoder setBuffer:stub_buffer_ offset:0 atIndex:0];
              [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1)
                      threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
              [encoder endEncoding];
              [command_buffer commit];
              if (prev_command_buffer != nil) [prev_command_buffer waitUntilScheduled];
              prev_command_buffer = command_buffer;
            } else {
              std::this_thread::sleep_for(std::chrono::milliseconds(1));
            }
          }
        }
      });
    }
    ~GpuAlarmClockInternal() {
      release_thread_ = true;
      alarm_thread_.join();
    }

   private:
    friend class GpuAlarmClock;
    std::atomic<int> active_alarms_{0};
    std::thread alarm_thread_;
    id<MTLCommandQueue> command_queue_;
    id<MTLDevice> device_;
    volatile bool release_thread_ = false;
    int total_alarms_ = 0;
  };
  static std::map<id<MTLDevice>, GpuAlarmClockInternal*>* alarms_;
  std::mutex alarms_mutex_;
  GpuAlarmClockInternal* internal_;
  bool started_ = false;
};
std::map<id<MTLDevice>, GpuAlarmClock::GpuAlarmClockInternal*>* GpuAlarmClock::alarms_ = nullptr;

// Forward declaration.
TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate);

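// Wraps a delegated TFLite subgraph for execution on Metal. Prepare() converts the delegated
// partition into a GraphFloat32 and allocates the Metal buffers and layout converters it needs;
// Invoke() moves (and, for quantized models, dequantizes/quantizes) tensor data between the
// TFLite tensors and those buffers and runs the Metal inference context.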
class Delegate {
  struct ValueRef {
    BHWC shape;
    int64_t tensor_id;
  };

 public:
  explicit Delegate(const TFLGpuDelegateOptions* options) {
    if (options) {
      options_ = *options;
    } else {
      options_ = TFLGpuDelegateOptionsDefault();
    }
    metal_device_ = MTLCreateSystemDefaultDevice();
    command_queue_ = [metal_device_ newCommandQueue];
    if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeAggressive) {
      gpu_alarm_clock_ = std::unique_ptr<GpuAlarmClock>(new GpuAlarmClock(command_queue_));
      NSString* code = @R"(
          kernel void ComputeFunction(device int* output_buffer [[buffer(0)]],
                                      constant int& value [[buffer(1)]]) {
            output_buffer[0] = value;
          }
        )";
      id<MTLComputePipelineState> signal_program;
      // TODO(impjdi): Properly handle returned status.
      CreateComputeProgram(metal_device_, code, @"ComputeFunction", nullptr, &signal_program)
          .IgnoreError();
      signal_program_ = signal_program;
      signal_buffer_ = [metal_device_ newBufferWithLength:sizeof(int) * 4
                                                  options:MTLResourceStorageModeShared |
                                                          MTLResourceHazardTrackingModeUntracked];
    }
  }

  absl::Status BindBufferToTensor(id<MTLBuffer> buffer, int tensor_index) {
    // The tensor index is expected to be an input or output tensor of the interpreter.
    // For quantized models, the buffer should be linked with its dequantized counterpart.
    if (quant_conversion_map_.find(tensor_index) != quant_conversion_map_.end()) {
      tensor_index = quant_conversion_map_[tensor_index];
      // Remove the [dequantized tensor ID] -> [quantized tensor ID] mapping to prevent an extra
      // dequant/quant on inputs/outputs.
      quant_conversion_map_.erase(tensor_index);
    }
    for (auto& input : graph_inputs_) {
      if (input.tensor_id == tensor_index) {
        input_output_buffers_[input.id] = buffer;
        if (bphwc4_buffers_[input.id] != buffer) {
          bphwc_buffers_updated_ = true;
        }
        bphwc4_buffers_[input.id] = buffer;
        input.set_externally = true;
        return absl::OkStatus();
      }
    }
    for (auto& output : graph_outputs_) {
      if (output.tensor_id == tensor_index) {
        input_output_buffers_[output.id] = buffer;
        if (bphwc4_buffers_[output.id] != buffer) {
          bphwc_buffers_updated_ = true;
        }
        bphwc4_buffers_[output.id] = buffer;
        output.set_externally = true;
        return absl::OkStatus();
      }
    }
    return absl::NotFoundError("Couldn't find tensor: " + std::to_string(tensor_index));
  }

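  // Makes the next Invoke() encode its work into the given command buffer instead of committing
  // one of its own. The stored buffer is cleared at the end of that Invoke(), so it has to be set
  // again before every call.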
  void SetCommandBuffer(id<MTLCommandBuffer> command_buffer) {
    external_command_buffer_ = command_buffer;
  }

  // This directs the runtime to allocate memory for input/output temporary
  // tensors that require dequantization/quantization.
  absl::Status GetRequiredTemporaries(TfLiteContext* context, TfLiteNode* node,
                                      TfLiteIntArray** temporaries_array_ptr) {
    if (quant_conversion_map_.empty()) return absl::OkStatus();

    std::vector<int> temporary_tensor_ids;
    for (auto index : input_tensor_ids_) {
      if (quant_conversion_map_.find(index) != quant_conversion_map_.end()) {
        temporary_tensor_ids.push_back(index);
      }
    }
    for (auto index : output_tensor_ids_) {
      if (quant_conversion_map_.find(index) != quant_conversion_map_.end()) {
        temporary_tensor_ids.push_back(index);
      }
    }
    *temporaries_array_ptr = TfLiteIntArrayCreate(temporary_tensor_ids.size());
    for (int i = 0; i < temporary_tensor_ids.size(); ++i) {
      (*temporaries_array_ptr)->data[i] = temporary_tensor_ids[i];
    }
    return absl::OkStatus();
  }

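  // Builds the GPU graph for the delegated partition, records which TFLite tensors map to graph
  // inputs/outputs, and pre-allocates the BHWC (CPU/GPU shared) and BPHWC4 (GPU layout) buffers
  // plus the converters between the two layouts.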
  absl::Status Prepare(TfLiteContext* context, const TfLiteDelegateParams* delegate_params) {
    // Extract the TFLite delegate execution plan from the context and convert it into
    // GraphFloat32.
    GraphFloat32 graph;
    quant_conversion_map_.clear();
    if (options_.enable_quantization) {
      RETURN_IF_ERROR(BuildFinalModel(context, delegate_params, &graph, &quant_conversion_map_));
    } else {
      RETURN_IF_ERROR(BuildFinalModel(context, delegate_params, &graph));
    }

    // TODO(impjdi): Remove code duplication.
    auto values = graph.values();
    auto find_value = [&](int tensor_index) -> Value* {
      for (auto value : values) {
        if (value->tensor.ref == tensor_index) return value;
      }
      return nullptr;
    };
    tensors_.reserve(values.back()->id + 1);
    for (const auto* value : values) {
      if (tensors_.size() <= value->id) tensors_.resize(value->id + 1);
      tensors_[value->id] = {
          value->tensor.shape,  // .shape
          value->tensor.ref,    // .tensor_id
      };
    }

    // Prepare graph inputs.
    //
    // Note that graph.inputs() cannot be used directly, as the notion of graph input has a
    // different meaning in the public API and the GPU-internal API.
    for (int tensor_index : TfLiteIntArrayView(delegate_params->input_tensors)) {
      auto* tensor = &context->tensors[tensor_index];
      if (IsConstantTensor(tensor)) continue;
      // For quantized models, the actual inputs of the GPU graph are float tensors, so the 8-bit
      // inputs to the delegate kernel need to be dequantized before being fed to the GPU graph.
      if (options_.enable_quantization &&
          quant_conversion_map_.find(tensor_index) != quant_conversion_map_.end()) {
        tensor_index = quant_conversion_map_[tensor_index];
        tensor = &context->tensors[tensor_index];
      }
      const auto* input = find_value(tensor_index);
      if (!input || tensor->type != TfLiteType::kTfLiteFloat32) {
        return absl::NotFoundError("Input tensor is not found in the graph.");
      }

      inputs_.push_back(input->id);
      input_tensor_ids_.push_back(tensor_index);
      tensor->buffer_handle = input->id;
      tensor->delegate = &delegate_;
    }

    // Prepare graph outputs.
    //
    // Note that graph.outputs() cannot be used directly, as the notion of graph output has a
    // different meaning in the public API and the GPU-internal API.
    for (int tensor_index : TfLiteIntArrayView(delegate_params->output_tensors)) {
      auto* tensor = &context->tensors[tensor_index];
      if (IsConstantTensor(tensor)) continue;
      // For quantized models, the actual outputs of the GPU graph are float tensors, so they have
      // to be quantized to become the 8-bit outputs of the delegate.
      if (options_.enable_quantization &&
          quant_conversion_map_.find(tensor_index) != quant_conversion_map_.end()) {
        tensor_index = quant_conversion_map_[tensor_index];
        tensor = &context->tensors[tensor_index];
      }
      const auto* output = find_value(tensor_index);
      if (!output || tensor->type != TfLiteType::kTfLiteFloat32) {
        return absl::NotFoundError("Output tensor is not found in the graph.");
      }

      outputs_.push_back(output->id);
      output_tensor_ids_.push_back(tensor_index);
      tensor->buffer_handle = output->id;
      tensor->delegate = &delegate_;
    }

    std::string device_name = std::string([[metal_device_ name] UTF8String]);
    GpuInfo gpu_info;
    GetGpuInfoFromDeviceDescription(device_name, GpuApi::kMetal, &gpu_info);
    size_t storage_type_size;
    CalculationsPrecision precision;
    if (options_.allow_precision_loss) {
      storage_type_size = sizeof(HalfBits);
      if (gpu_info.IsRoundToNearestSupported()) {
        precision = CalculationsPrecision::F16;
      } else {
        precision = CalculationsPrecision::F32_F16;
      }
    } else {
      storage_type_size = sizeof(float);
      precision = CalculationsPrecision::F32;
    }

    // TODO(impjdi): Merge logic with above.
    // Pre-allocate input and output Metal buffers.
    std::vector<::tflite::gpu::ValueId> input_ids;
    input_ids.reserve(inputs_.size());
    std::map<::tflite::gpu::ValueId, BHWC> input_dimensions;
    graph_inputs_.reserve(inputs_.size());
    for (const ValueId input : inputs_) {
      const auto& input_tensor = tensors_[input];
      const auto tensor_id = input_tensor.tensor_id;
      input_ids.push_back(input);
      if (input_tensor.shape.b != 1) {
        return absl::UnimplementedError("Batching is not supported yet.");
      }
      input_dimensions[input] = input_tensor.shape;
      graph_inputs_.push_back({
          input,               // .id
          tensor_id,           // .tensor_id
          input_tensor.shape,  // .shape
          false,               // .set_externally
      });
      int bhwc_length = static_cast<int>(sizeof(float) * input_tensor.shape.DimensionsProduct());
      int bphwc4_length =
          static_cast<int>(storage_type_size * GetElementsSizeForPHWC4(input_tensor.shape));
      id<MTLBuffer> buffer = [metal_device_ newBufferWithLength:bhwc_length
                                                        options:MTLResourceStorageModeShared];
      input_output_buffers_[input] = buffer;
      if (options_.allow_precision_loss || input_tensor.shape.c != 4) {
        bphwc4_buffers_[input] = [metal_device_ newBufferWithLength:bphwc4_length
                                                            options:MTLResourceStorageModeShared];
        if (converter_to_BPHWC4_ == nil) {
          converter_to_BPHWC4_ =
              [[TFLBufferConvert alloc] initWithDevice:metal_device_
                                             isFloat16:options_.allow_precision_loss
                                       convertToPBHWC4:true];
          if (converter_to_BPHWC4_ == nil) {
            return absl::InternalError("Error initializing the input buffer converter.");
          }
        }
      } else {
        bphwc4_buffers_[input] = buffer;
      }
    }

    std::vector<::tflite::gpu::ValueId> output_ids;
    output_ids.reserve(outputs_.size());
    graph_outputs_.reserve(outputs_.size());
    for (const ValueId output : outputs_) {
      const auto& output_tensor = tensors_[output];
      const auto tensor_id = output_tensor.tensor_id;
      output_ids.push_back(output);
      graph_outputs_.push_back({
          output,               // .id
          tensor_id,            // .tensor_id
          output_tensor.shape,  // .shape
          false,                // .set_externally
      });
      // Create the BHWC buffer.
      int bhwc_length = static_cast<int>(sizeof(float) * output_tensor.shape.DimensionsProduct());
      int bphwc4_length =
          static_cast<int>(storage_type_size * GetElementsSizeForPHWC4(output_tensor.shape));
      id<MTLBuffer> buffer = [metal_device_ newBufferWithLength:bhwc_length
                                                        options:MTLResourceStorageModeShared];
      input_output_buffers_[output] = buffer;
      if (options_.allow_precision_loss || output_tensor.shape.c != 4) {
        bphwc4_buffers_[output] = [metal_device_ newBufferWithLength:bphwc4_length
                                                             options:MTLResourceStorageModeShared];
        if (converter_from_BPHWC4_ == nil) {
          converter_from_BPHWC4_ =
              [[TFLBufferConvert alloc] initWithDevice:metal_device_
                                             isFloat16:options_.allow_precision_loss
                                       convertToPBHWC4:false];
          if (converter_from_BPHWC4_ == nil) {
            return absl::InternalError("Error initializing the output buffer converter.");
          }
        }
      } else {
        bphwc4_buffers_[output] = buffer;
      }
    }
    bphwc_buffers_updated_ = true;

    InferenceContext::CreateInferenceInfo create_info;
    create_info.precision = precision;
    create_info.storage_type = TensorStorageType::BUFFER;
    RETURN_IF_ERROR(
        inference_context_.InitFromGraphWithTransforms(create_info, &graph, metal_device_));
    return absl::OkStatus();
  }

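  // Runs one inference step: dequantizes 8-bit inputs if needed, copies CPU input data into the
  // shared BHWC buffers and converts it to PHWC4, encodes the GPU graph, converts outputs back to
  // BHWC, waits according to options_.wait_type, and copies/quantizes results into the TFLite
  // output tensors. Tensors bound to external buffers are skipped in the copy/convert steps.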
  absl::Status Invoke(TfLiteContext* context) {
    if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeAggressive)
      gpu_alarm_clock_->Stop();
    // Only synchronization is needed here, so volatile works better than an atomic, which reads
    // from global memory each time.
    __block volatile bool buffer_completed = false;
    id<MTLCommandBuffer> command_buffer = external_command_buffer_;
    if (external_command_buffer_ == nil) {
      command_buffer = [command_queue_ commandBuffer];
    }
    const bool flush = external_command_buffer_ == nil &&
        (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive ||
         options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeAggressive);
    const int flush_period = 8;

    const bool is_quantized_model = !quant_conversion_map_.empty();
    if (is_quantized_model) {
      RETURN_IF_ERROR(DequantizeInputs(context, input_tensor_ids_, quant_conversion_map_));
    }

    // Convert CPU HWC input data to PHWC4 and fill the GPU buffers.
    for (const auto& input : graph_inputs_) {
      if (input.set_externally) continue;
      // The user provides data in CPU memory for this buffer; it needs to be copied to the
      // MTLBuffer.
      TfLiteTensor* tensor = &context->tensors[input.tensor_id];
      void* gpu_ptr = [input_output_buffers_[input.id] contents];
      std::memcpy(gpu_ptr, tensor->data.f, input.shape.DimensionsProduct() * sizeof(float));
      if (input_output_buffers_[input.id] == bphwc4_buffers_[input.id]) continue;
      id<MTLComputeCommandEncoder> input_encoder = [command_buffer computeCommandEncoder];
      [converter_to_BPHWC4_ convertWithEncoder:input_encoder
                                         shape:input.shape
                                  sourceBuffer:input_output_buffers_[input.id]
                               convertedBuffer:bphwc4_buffers_[input.id]];
      [input_encoder endEncoding];
    }

    if (bphwc_buffers_updated_) {
      inference_context_.UpdatePreallocatedTensors(bphwc4_buffers_);
      bphwc_buffers_updated_ = false;
    }

    @autoreleasepool {
      if (flush) {
        [command_buffer commit];
        inference_context_.EncodeWithCommandQueue(command_queue_, flush_period);
        command_buffer = [command_queue_ commandBuffer];
      } else {
        inference_context_.EncodeWithCommandBuffer(command_buffer);
      }
    }

    for (const auto& output : graph_outputs_) {
      if (output.set_externally) continue;
      if (bphwc4_buffers_[output.id] == input_output_buffers_[output.id]) continue;
      id<MTLComputeCommandEncoder> output_encoder = [command_buffer computeCommandEncoder];
      [converter_from_BPHWC4_ convertWithEncoder:output_encoder
                                           shape:output.shape
                                    sourceBuffer:bphwc4_buffers_[output.id]
                                 convertedBuffer:input_output_buffers_[output.id]];
      [output_encoder endEncoding];
    }

    if (external_command_buffer_ == nil) {
      if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive) {
        [command_buffer addCompletedHandler:^(id<MTLCommandBuffer>) {
          buffer_completed = true;
        }];
      }
      [command_buffer commit];
      if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive) {
        while (!buffer_completed) {
          // Busy wait. Spin on a local variable; volatile would hit RAM on every access.
          for (volatile int i = 0; i < 100; i++) {
          }
        }
      } else if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypePassive) {
        // Passive wait: this thread sleeps until the GPU finishes.
        [command_buffer waitUntilCompleted];
      } else if (options_.wait_type == TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeAggressive) {
        id<MTLCommandBuffer> signal_cb = [command_queue_ commandBuffer];
        id<MTLComputeCommandEncoder> signal_encoder = [signal_cb computeCommandEncoder];
        [signal_encoder setComputePipelineState:signal_program_];
        [signal_encoder setBuffer:signal_buffer_ offset:0 atIndex:0];
        signal_value_++;
        [signal_encoder setBytes:&signal_value_ length:sizeof(int) atIndex:1];
        [signal_encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1)
                       threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
        [signal_encoder endEncoding];
        [signal_cb commit];
        gpu_alarm_clock_->Start();
        const int* signal_ptr = reinterpret_cast<const int*>([signal_buffer_ contents]);
        while (signal_ptr[0] != signal_value_) {
          // Busy wait. Spin on a local variable to avoid RAM pressure.
          for (volatile int i = 0; i < 100; i++) {
          }
        }
      }
    } else {
      // The external command buffer must be set again before every Invoke call.
      external_command_buffer_ = nil;
      // An external command buffer is assigned, so all output buffers must be controlled by the
      // user.
      for (const auto& output : graph_outputs_) {
        if (!output.set_externally) {
          return absl::InternalError(
              "External command encoder is used, but not all output buffers are bound.");
        }
      }
      return absl::OkStatus();
    }

    // Retrieve data from the GPU and convert from PHWC4 to HWC.
    for (const auto& output : graph_outputs_) {
      if (output.set_externally) continue;
      // The user retrieves data in CPU memory for this buffer; it needs to be copied from the
      // MTLBuffer.
      TfLiteTensor* tensor = context->tensors + output.tensor_id;
      const void* gpu_ptr = [input_output_buffers_[output.id] contents];
      std::memcpy(tensor->data.f, gpu_ptr, output.shape.DimensionsProduct() * sizeof(float));
    }
    if (is_quantized_model) {
      RETURN_IF_ERROR(QuantizeOutputs(context, output_tensor_ids_, quant_conversion_map_));
    }
    return absl::OkStatus();
  }

  const TFLGpuDelegateOptions options() const { return options_; }

  TfLiteDelegate* tflite_delegate() { return &delegate_; }

 private:
  TfLiteDelegate delegate_ = {
      reinterpret_cast<void*>(this),  // .data_
      DelegatePrepare,                // .Prepare
      nullptr,                        // .CopyFromBufferHandle
      nullptr,                        // .CopyToBufferHandle
      nullptr,                        // .FreeBufferHandle
      kTfLiteDelegateFlagsNone,       // .flags
  };

  TFLGpuDelegateOptions options_;

  id<MTLDevice> metal_device_;

  std::vector<ValueRef> tensors_;  // indexed by ValueId
  std::vector<ValueId> inputs_;
  std::vector<ValueId> outputs_;
  std::vector<int64_t> input_tensor_ids_;
  std::vector<int64_t> output_tensor_ids_;
  // Whenever quantized inference is enabled, this maps the tensor index of each
  // originally quantized (8-bit) tensor to its float version added in
  // model_builder - and vice versa.
  absl::flat_hash_map<int, int> quant_conversion_map_;

  InferenceContext inference_context_;
  // Input and output buffers that are passed to the Metal inference engine.
  std::map<::tflite::gpu::ValueId, id<MTLBuffer>> input_output_buffers_;
  std::map<::tflite::gpu::ValueId, id<MTLBuffer>> bphwc4_buffers_;
  bool bphwc_buffers_updated_ = true;
  TFLBufferConvert* converter_to_BPHWC4_ = nil;
  TFLBufferConvert* converter_from_BPHWC4_ = nil;

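  // Per-tensor bookkeeping for a graph input/output: its ValueId in the GPU graph, the TFLite
  // tensor index it mirrors, its shape, and whether the user bound an external MTLBuffer to it.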
  struct BufferDescriptor {
    ValueId id;
    int64_t tensor_id;
    BHWC shape;
    bool set_externally;  // the user fills/retrieves data directly in this MTLBuffer
  };
  std::vector<BufferDescriptor> graph_inputs_;
  std::vector<BufferDescriptor> graph_outputs_;

  id<MTLCommandBuffer> external_command_buffer_ = nil;
  id<MTLCommandQueue> command_queue_;
  std::unique_ptr<GpuAlarmClock> gpu_alarm_clock_;
  id<MTLComputePipelineState> signal_program_;
  id<MTLBuffer> signal_buffer_;
  int signal_value_ = 0;
};

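// Helpers to recover the Delegate instance from the opaque pointers TFLite hands back.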
Delegate* GetMetalDelegate(TfLiteNode* node) {
  return reinterpret_cast<Delegate*>(node->user_data);
}

Delegate* GetMetalDelegate(TfLiteDelegate* delegate) {
  return reinterpret_cast<Delegate*>(delegate->data_);
}

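// Entry point called by TFLite when the delegate is applied: selects the ops the Metal backend
// supports and replaces them with delegate kernels whose callbacks forward to Delegate.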
TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
  const TfLiteRegistration kRegistration = {
      // .init
      [](TfLiteContext* context, const char* buffer, size_t) -> void* {
        const auto* params = reinterpret_cast<const TfLiteDelegateParams*>(buffer);
        auto* metal_delegate = GetMetalDelegate(params->delegate);
        // Everything below should ideally happen in the .prepare callback, but TFLite does not
        // allow that.
        const auto status = metal_delegate->Prepare(context, params);
        if (status.ok()) return metal_delegate;
        TF_LITE_KERNEL_LOG(context, "TfLiteMetalDelegate Prepare: %s",
                           std::string(status.message()).c_str());
        return nullptr;
      },
      // .free
      [](TfLiteContext*, void* buffer) -> void {},
      // .prepare
      [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
        if (!node->user_data) {
          return kTfLiteError;
        }

        auto* gpu_delegate_kernel = GetMetalDelegate(node);
        const auto status =
            gpu_delegate_kernel->GetRequiredTemporaries(context, node, &node->temporaries);
        if (!status.ok()) {
          TF_LITE_KERNEL_LOG(context, "TfLiteMetalDelegate Prepare: %s",
                             std::string(status.message()).c_str());
          return kTfLiteError;
        }
        return node->user_data ? kTfLiteOk : kTfLiteError;
      },
      // .invoke
      [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
        const auto status = GetMetalDelegate(node)->Invoke(context);
        if (status.ok()) return kTfLiteOk;
        TF_LITE_KERNEL_LOG(context, "TfLiteMetalDelegate Invoke: %s",
                           std::string(status.message()).c_str());
        return kTfLiteError;
      },
      nullptr,                // .profiling_string
      0,                      // .builtin_code
      "TfLiteMetalDelegate",  // .custom_name
      1,                      // .version
  };
  TfLiteIntArray* ops_to_replace =
      GetOpsToReplace(context, GetMetalDelegate(delegate)->options().enable_quantization);
  const auto status = context->ReplaceNodeSubsetsWithDelegateKernels(context, kRegistration,
                                                                     ops_to_replace, delegate);
  TfLiteIntArrayFree(ops_to_replace);
  return status;
}

}  // namespace
}  // namespace metal
}  // namespace gpu
}  // namespace tflite

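// A minimal usage sketch for the C API below (illustrative only; it assumes a
// std::unique_ptr<tflite::Interpreter> named `interpreter` was built elsewhere from a model):
//
//   TFLGpuDelegateOptions options = TFLGpuDelegateOptionsDefault();
//   TfLiteDelegate* delegate = TFLGpuDelegateCreate(&options);
//   if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) {
//     // Fall back to CPU execution.
//   }
//   // ... run inference via interpreter->Invoke() ...
//   interpreter.reset();  // Release the interpreter before deleting the delegate.
//   TFLGpuDelegateDelete(delegate);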
TfLiteDelegate* TFLGpuDelegateCreate(const TFLGpuDelegateOptions* options) {
  TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO, "Created TensorFlow Lite delegate for Metal.");
  auto* metal_delegate = new ::tflite::gpu::metal::Delegate(options);
  return metal_delegate ? metal_delegate->tflite_delegate() : nullptr;
}

void TFLGpuDelegateDelete(TfLiteDelegate* delegate) {
  delete ::tflite::gpu::metal::GetMetalDelegate(delegate);
}

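// Binds a caller-owned MTLBuffer to an input or output tensor so the delegate reads/writes that
// buffer directly instead of its internally allocated one. Returns false if the tensor is not an
// input or output of the delegated graph.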
bool TFLGpuDelegateBindMetalBufferToTensor(TfLiteDelegate* delegate, int tensor_index,
                                           id<MTLBuffer> buffer) {
  auto* metal_delegate = ::tflite::gpu::metal::GetMetalDelegate(delegate);
  return metal_delegate && metal_delegate->BindBufferToTensor(buffer, tensor_index).ok();
}

// Note: This function is not exposed in `metal_delegate.h`, but it's exposed in
// `metal_delegate_internal.h`.
bool TFLGpuDelegateSetCommandBuffer(TfLiteDelegate* delegate,
                                    id<MTLCommandBuffer> command_buffer) {
  auto* metal_delegate = ::tflite::gpu::metal::GetMetalDelegate(delegate);
  if (!metal_delegate) return false;
  metal_delegate->SetCommandBuffer(command_buffer);
  return true;
}

TFLGpuDelegateOptions TFLGpuDelegateOptionsDefault() {
  TFLGpuDelegateOptions options = {
      .allow_precision_loss = false,
      .wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypePassive,
      .enable_quantization = true,
  };
  return options;
}