/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/delegate.h"

#include <cstdint>
#include <memory>
#include <thread>  // NOLINT(build/c++11)
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/types/span.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/gpu/api.h"
#include "tensorflow/lite/delegates/gpu/cl/api.h"
#include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h"
#include "tensorflow/lite/delegates/gpu/cl/tensor_type_util.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/quantization_util.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/serialization.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/minimal_logging.h"

#ifndef CL_DELEGATE_NO_GL
#include "tensorflow/lite/delegates/gpu/gl/api2.h"
#endif

namespace tflite {
namespace gpu {
namespace {

using delegates::Serialization;
using delegates::SerializationParams;

constexpr char kSerializedDataPrefix[] = "gpuv2_data_";

InferencePriority ToPriority(int32_t priority) {
  switch (priority) {
    case TFLITE_GPU_INFERENCE_PRIORITY_AUTO:
      return InferencePriority::AUTO;
    case TFLITE_GPU_INFERENCE_PRIORITY_MAX_PRECISION:
      return InferencePriority::MAX_PRECISION;
    case TFLITE_GPU_INFERENCE_PRIORITY_MIN_LATENCY:
      return InferencePriority::MIN_LATENCY;
    case TFLITE_GPU_INFERENCE_PRIORITY_MIN_MEMORY_USAGE:
      return InferencePriority::MIN_MEMORY_USAGE;
  }
  return InferencePriority::UNKNOWN;
}

InferenceUsage ToUsage(int32_t usage) {
  switch (usage) {
    case TFLITE_GPU_INFERENCE_PREFERENCE_FAST_SINGLE_ANSWER:
      return InferenceUsage::FAST_SINGLE_ANSWER;
    case TFLITE_GPU_INFERENCE_PREFERENCE_SUSTAINED_SPEED:
      return InferenceUsage::SUSTAINED_SPEED;
  }
  return InferenceUsage::UNKNOWN;
}

// Forward declarations.
TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate);

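// Wraps the TfLiteDelegate handed to the interpreter together with the
// user-supplied TfLiteGpuDelegateOptionsV2 and optional on-disk serialization
// support. A single Delegate instance is shared by all DelegateKernels that
// the partitioner creates for it.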
class Delegate {
 public:
  explicit Delegate(const TfLiteGpuDelegateOptionsV2* options)
      : num_delegate_kernels_(0) {
    delegate_.data_ = reinterpret_cast<void*>(this);
    delegate_.Prepare = DelegatePrepare;
    delegate_.CopyFromBufferHandle = nullptr;
    delegate_.CopyToBufferHandle = nullptr;
    delegate_.FreeBufferHandle = nullptr;
    delegate_.flags = kTfLiteDelegateFlagsNone;
    options_ = options ? *options : TfLiteGpuDelegateOptionsV2Default();
    if (options_.max_delegated_partitions <= 0) {
      options_.max_delegated_partitions = 1;
    }
    // Read from options_ (not options): options may be null, in which case
    // the defaults assigned above apply.
    if (options_.experimental_flags &
            TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_SERIALIZATION &&
        options_.model_token && options_.serialization_dir) {
      SerializationParams params;
      params.model_token = options_.model_token;
      params.cache_dir = options_.serialization_dir;
      serialization_.reset(new Serialization(params));
    }
  }

  TfLiteDelegate* tflite_delegate() { return &delegate_; }
  Serialization* serialization() { return serialization_.get(); }
  const TfLiteGpuDelegateOptionsV2& options() const { return options_; }

  bool IsQuantOpsAllowed() const {
    return options_.experimental_flags &
           TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_QUANT;
  }
  int MaxDelegatedPartitions() const {
    return options_.max_delegated_partitions;
  }
  int num_delegate_kernels() const { return num_delegate_kernels_; }

 private:
  TfLiteDelegate delegate_;
  TfLiteGpuDelegateOptionsV2 options_;
  int num_delegate_kernels_ = 0;

  std::unique_ptr<Serialization> serialization_;

  friend class DelegateKernel;
};

// Represents the execution of a subset of nodes on the GPU.
class DelegateKernel {
 public:
  explicit DelegateKernel(Delegate* delegate) : delegate_(delegate) {
    ++delegate_->num_delegate_kernels_;
  }
  ~DelegateKernel() { --delegate_->num_delegate_kernels_; }

  absl::Status Prepare(TfLiteContext* context,
                       const TfLiteDelegateParams* delegate_params) {
    thread_id_prepare_ = std::this_thread::get_id();

    // Extract the TFLite delegate execution plan from the context and convert
    // it into GraphFloat32.
    GraphFloat32 graph;
    std::vector<uint32_t> input_refs;
    std::vector<uint32_t> output_refs;
    RETURN_IF_ERROR(InitializeGraph(context, delegate_params, &graph,
                                    &input_refs, &output_refs));

    std::unique_ptr<InferenceBuilder> builder;
    bool graph_is_destroyed;
    const int experimental_flags = delegate_->options().experimental_flags;
    if (experimental_flags & TFLITE_GPU_EXPERIMENTAL_FLAGS_CL_ONLY) {
      RETURN_IF_ERROR(InitializeOpenClApi(&graph, &builder, &graph_is_destroyed,
                                          context, delegate_params,
                                          delegate_->serialization()));
    } else if (experimental_flags & TFLITE_GPU_EXPERIMENTAL_FLAGS_GL_ONLY) {
      RETURN_IF_ERROR(InitializeOpenGlApi(&graph, &builder));
    } else {
      // By default, we try CL first & fall back to GL if that fails.
      absl::Status status =
          InitializeOpenClApi(&graph, &builder, &graph_is_destroyed, context,
                              delegate_params, delegate_->serialization());
      if (!status.ok()) {
        TF_LITE_KERNEL_LOG(context, std::string(status.message()).c_str());
        TF_LITE_KERNEL_LOG(context, "Falling back to OpenGL");

        // The graph needs to be re-created because it was moved above.
        GraphFloat32 graph2;
        if (graph_is_destroyed) {
          RETURN_IF_ERROR(InitializeGraph(context, delegate_params, &graph2,
                                          &input_refs, &output_refs));
        }
        RETURN_IF_ERROR(InitializeOpenGlApi(
            graph_is_destroyed ? &graph2 : &graph, &builder));
      }
    }

    // At this point TFLite hasn't allocated tensors yet, so collect the
    // indices here and bind the actual input/output tensors later.
    input_indices_.reserve(input_refs.size());
    for (uint32_t tensor_index : input_refs) {
      const int64_t object_index = input_indices_.size();
      input_indices_.push_back(tensor_index);
      RETURN_IF_ERROR(
          builder->SetInputObjectDef(object_index, GetObjectDef(tensor_index)));
    }
    output_indices_.reserve(output_refs.size());
    for (uint32_t tensor_index : output_refs) {
      const int64_t object_index = output_indices_.size();
      output_indices_.push_back(tensor_index);
      RETURN_IF_ERROR(builder->SetOutputObjectDef(object_index,
                                                  GetObjectDef(tensor_index)));
    }

    return builder->Build(&runner_);
  }


  // This directs the runtime to allocate memory for input/output temporary
  // tensors that require dequantization/quantization.
  absl::Status GetRequiredTemporaries(TfLiteContext* context, TfLiteNode* node,
                                      TfLiteIntArray** temporaries_array_ptr) {
    if (quant_conversion_map_.empty()) return absl::OkStatus();

    std::vector<int> temporary_tensors;
    for (auto index : input_indices_) {
      if (quant_conversion_map_.find(index) != quant_conversion_map_.end()) {
        temporary_tensors.push_back(index);
      }
    }
    for (auto index : output_indices_) {
      if (quant_conversion_map_.find(index) != quant_conversion_map_.end()) {
        temporary_tensors.push_back(index);
      }
    }
    *temporaries_array_ptr = TfLiteIntArrayCreate(temporary_tensors.size());
    for (int i = 0; i < temporary_tensors.size(); ++i) {
      (*temporaries_array_ptr)->data[i] = temporary_tensors[i];
    }
    return absl::OkStatus();
  }

  absl::Status Invoke(TfLiteContext* context) {
    if (thread_id_prepare_ != std::this_thread::get_id()) {
      TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
                 "GpuDelegate invoke thread != prepare thread");
      if (enforce_same_thread_) {
        return absl::FailedPreconditionError(
            "GpuDelegate must run on the same thread where it was "
            "initialized.");
      }
    }

    const bool is_dequant_required = !quant_conversion_map_.empty();
    if (is_dequant_required) {
      RETURN_IF_ERROR(
          DequantizeInputs(context, input_indices_, quant_conversion_map_));
    }
    RETURN_IF_ERROR(SetInputsAndOutputs(context));
    RETURN_IF_ERROR(runner_->Run());
    if (is_dequant_required) {
      RETURN_IF_ERROR(
          QuantizeOutputs(context, output_indices_, quant_conversion_map_));
    }
    return absl::OkStatus();
  }

 private:
  absl::Status SetInputsAndOutputs(TfLiteContext* context) {
    for (int i = 0; i < input_indices_.size(); ++i) {
      RETURN_IF_ERROR(runner_->SetInputObject(
          i, GetTensorObject(input_indices_[i], context)));
    }
    for (int i = 0; i < output_indices_.size(); ++i) {
      RETURN_IF_ERROR(runner_->SetOutputObject(
          i, GetTensorObject(output_indices_[i], context)));
    }
    return absl::OkStatus();
  }

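  // All delegated inputs/outputs are exchanged with the runtime as
  // user-provided CPU memory holding FP32 data in BHWC layout.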
  ObjectDef GetObjectDef(int index) const {
    ObjectDef default_object_def;
    default_object_def.data_type = DataType::FLOAT32;
    default_object_def.data_layout = DataLayout::BHWC;
    default_object_def.object_type = ObjectType::CPU_MEMORY;
    default_object_def.user_provided = true;
    return default_object_def;
  }

  TensorObject GetTensorObject(int index, TfLiteContext* context) const {
    auto& tensor = context->tensors[index];
    return MakeCpuMemory(absl::MakeSpan(tensor.data.raw, tensor.bytes));
  }

 private:
  absl::Status InitializeGraph(TfLiteContext* context,
                               const TfLiteDelegateParams* delegate_params,
                               GraphFloat32* graph,
                               std::vector<uint32_t>* input_refs,
                               std::vector<uint32_t>* output_refs) {
    quant_conversion_map_.clear();
    if (delegate_->IsQuantOpsAllowed()) {
      RETURN_IF_ERROR(BuildFinalModel(context, delegate_params, graph,
                                      &quant_conversion_map_));
    } else {
      RETURN_IF_ERROR(BuildFinalModel(context, delegate_params, graph));
    }

    input_refs->clear();
    output_refs->clear();
    const auto inputs = graph->inputs();
    input_refs->reserve(inputs.size());
    for (const auto& input : inputs) {
      input_refs->push_back(input->tensor.ref);
    }
    const auto outputs = graph->outputs();
    output_refs->reserve(outputs.size());
    for (const auto& output : outputs) {
      output_refs->push_back(output->tensor.ref);
    }

    return absl::OkStatus();
  }

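  // Creates the OpenCL inference environment and builds an InferenceBuilder
  // from *graph. The graph is consumed (moved) whenever a new builder has to
  // be compiled, which is signalled via *graph_is_destroyed. When
  // serialization is available, a cached serialized model is tried first;
  // otherwise the freshly built model is serialized and saved for later runs.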
  absl::Status InitializeOpenClApi(GraphFloat32* graph,
                                   std::unique_ptr<InferenceBuilder>* builder,
                                   bool* graph_is_destroyed,
                                   TfLiteContext* context,
                                   const TfLiteDelegateParams* delegate_params,
                                   Serialization* serialization = nullptr) {
    *graph_is_destroyed = false;
    cl::InferenceEnvironmentOptions env_options;
    cl::InferenceEnvironmentProperties properties;

    // OpenCL initialization is parameterized by these InferenceOptions.
    auto delegate_options = delegate_->options();
    cl::InferenceOptions options;
    // If is_precision_loss_allowed == -1, ignore it and derive the options
    // from the inference priorities alone.
    if (delegate_options.is_precision_loss_allowed == -1) {
      options.priority1 = ToPriority(delegate_options.inference_priority1);
      options.priority2 = ToPriority(delegate_options.inference_priority2);
      options.priority3 = ToPriority(delegate_options.inference_priority3);
    } else {
      // The user set is_precision_loss_allowed explicitly, so honor it.
      if (delegate_options.is_precision_loss_allowed == 0) {
        options.priority1 = InferencePriority::MAX_PRECISION;
      } else {
        options.priority1 = InferencePriority::MIN_LATENCY;
      }
    }
    options.usage = ToUsage(delegate_options.inference_preference);

    if (!serialization) {
      // This path is faster when no serialization is involved.
      RETURN_IF_ERROR(cl::NewInferenceEnvironment(env_options, &cl_environment_,
                                                  &properties));
      *graph_is_destroyed = true;
      RETURN_IF_ERROR(cl_environment_->NewInferenceBuilder(
          options, std::move(*graph), builder));
    } else {
      // If serialized data is found, initialize CL from it and return early.
      if (MaybeInitializeSerializedOpenCL(context, delegate_params, builder,
                                          &options, &env_options, &properties,
                                          serialization)
              .ok()) {
        return absl::OkStatus();
      }

      RETURN_IF_ERROR(cl::NewInferenceEnvironment(env_options, &cl_environment_,
                                                  &properties));
      *graph_is_destroyed = true;
      std::vector<uint8_t> serialized_model;
      RETURN_IF_ERROR(cl_environment_->BuildSerializedModel(
          options, std::move(*graph), &serialized_model));
      RETURN_IF_ERROR(
          cl_environment_->NewInferenceBuilder(serialized_model, builder));

      RETURN_IF_ERROR(SaveSerializedOpenCL(context, delegate_params, &options,
                                           serialization, serialized_model));
    }

    TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO,
                         "Initialized OpenCL-based API.");
    return absl::OkStatus();
  }

  // Returns Ok only if serialized data is successfully found.
  absl::Status MaybeInitializeSerializedOpenCL(
      TfLiteContext* context, const TfLiteDelegateParams* delegate_params,
      std::unique_ptr<InferenceBuilder>* builder, cl::InferenceOptions* options,
      cl::InferenceEnvironmentOptions* env_options,
      cl::InferenceEnvironmentProperties* properties,
      Serialization* serialization) {
    if (!serialization) return absl::InvalidArgumentError("No serialization");
    // We use a fingerprint of the options to ensure compatibility.
    std::string options_fingerprint =
        delegates::StrFingerprint(options, sizeof(cl::InferenceOptions));
    auto data_key = serialization->GetEntryForKernel(
        std::string(kSerializedDataPrefix) + options_fingerprint, context,
        delegate_params);

    std::string model_data;
    auto model_data_status = data_key.GetData(context, &model_data);
    if (model_data_status == kTfLiteOk) {
      absl::Span<const uint8_t> model_span = absl::Span<const uint8_t>{
          reinterpret_cast<const uint8_t*>(model_data.data()),
          model_data.size()};
      RETURN_IF_ERROR(cl::NewInferenceEnvironment(
          *env_options, &cl_environment_, properties));
      RETURN_IF_ERROR(
          cl_environment_->NewInferenceBuilder(model_span, builder));
      TFLITE_LOG_PROD_ONCE(
          tflite::TFLITE_LOG_INFO,
          "Initialized OpenCL-based API from serialized data.");
      return absl::OkStatus();
    }

    return absl::NotFoundError("Serialization data not found");
  }

  // Returns Ok only if serialization happens successfully.
  absl::Status SaveSerializedOpenCL(
      TfLiteContext* context, const TfLiteDelegateParams* delegate_params,
      cl::InferenceOptions* options, Serialization* serialization,
      const std::vector<uint8_t>& serialized_model) {
    if (!serialization) return absl::InvalidArgumentError("No serialization");
    // We use a fingerprint of the options to ensure compatibility.
    std::string options_fingerprint =
        delegates::StrFingerprint(options, sizeof(cl::InferenceOptions));

    // Save data.
    auto data_key = serialization->GetEntryForKernel(
        std::string(kSerializedDataPrefix) + options_fingerprint, context,
        delegate_params);
    auto save_status = data_key.SetData(
        context, reinterpret_cast<const char*>(serialized_model.data()),
        serialized_model.size());
    if (save_status != kTfLiteOk) {
      return absl::InvalidArgumentError("Failed to save serialized data");
    }
    return absl::OkStatus();
  }

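  // Builds an OpenGL-backed InferenceBuilder from *graph. Compiled out when
  // CL_DELEGATE_NO_GL is defined. GL inference has to stay on the thread that
  // initialized it, which is why enforce_same_thread_ is set here.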
  absl::Status InitializeOpenGlApi(GraphFloat32* graph,
                                   std::unique_ptr<InferenceBuilder>* builder) {
#ifndef CL_DELEGATE_NO_GL
    gl::InferenceEnvironmentOptions env_options;
    gl::InferenceEnvironmentProperties properties;
    RETURN_IF_ERROR(
        NewInferenceEnvironment(env_options, &gl_environment_, &properties));
    auto delegate_options = delegate_->options();
    gl::InferenceOptions options;
    options.usage = ToUsage(delegate_options.inference_preference);
    options.priority1 = ToPriority(delegate_options.inference_priority1);
    options.priority2 = ToPriority(delegate_options.inference_priority2);
    options.priority3 = ToPriority(delegate_options.inference_priority3);
    RETURN_IF_ERROR(gl_environment_->NewInferenceBuilder(std::move(*graph),
                                                         options, builder));
    enforce_same_thread_ = true;
    TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO,
                         "Initialized OpenGL-based API.");
    return absl::OkStatus();
#else
    return absl::UnavailableError("OpenGL-based API disabled");
#endif
  }

  // The Delegate instance that's shared across all DelegateKernel instances.
  Delegate* const delegate_;  // doesn't own the memory.
  std::unique_ptr<cl::InferenceEnvironment> cl_environment_;
#ifndef CL_DELEGATE_NO_GL
  std::unique_ptr<gl::InferenceEnvironment> gl_environment_;
#endif
  std::unique_ptr<InferenceRunner> runner_;
  std::vector<int64_t> input_indices_;
  std::vector<int64_t> output_indices_;
  // When quantized inference is enabled, maps the tensor index of each
  // originally quantized (8-bit) tensor to its float version added in
  // model_builder, and vice versa.
  absl::flat_hash_map<int, int> quant_conversion_map_;
  std::thread::id thread_id_prepare_;  // thread id used for Prepare()
  bool enforce_same_thread_ = false;  // Enforce same thread for Invoke().
};

inline DelegateKernel* GetDelegateKernel(TfLiteNode* node) {
  return reinterpret_cast<DelegateKernel*>(node->user_data);
}

inline Delegate* GetDelegate(TfLiteDelegate* delegate) {
  return reinterpret_cast<Delegate*>(delegate->data_);
}

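// Entry point invoked by the interpreter: selects the supported ops, replaces
// them with delegate kernels backed by DelegateKernel, and reports how many
// GPU delegate kernels were created.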
TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
  const TfLiteRegistration kRegistration = {
      // .init
      [](TfLiteContext* context, const char* buffer, size_t) -> void* {
        const auto* params =
            reinterpret_cast<const TfLiteDelegateParams*>(buffer);
        auto* gpu_delegate = GetDelegate(params->delegate);
        // Everything below should happen in prepare function call, but TFLite
        // for whatever reason forbids that.
        auto gpu_delegate_kernel =
            absl::make_unique<DelegateKernel>(gpu_delegate);
        const auto status = gpu_delegate_kernel->Prepare(context, params);
        if (!status.ok()) {
          TF_LITE_KERNEL_LOG(context, "TfLiteGpuDelegate Init: %s",
                             std::string(status.message()).c_str());
          return nullptr;
        }
        return gpu_delegate_kernel.release();
      },
      // .free
      [](TfLiteContext*, void* buffer) -> void {
        delete reinterpret_cast<DelegateKernel*>(buffer);
      },
      // .prepare
      [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
        if (!node->user_data) {
          TF_LITE_KERNEL_LOG(
              context,
              "TfLiteGpuDelegate Prepare: delegate is not initialized");
          return kTfLiteError;
        }
        auto* gpu_delegate_kernel = GetDelegateKernel(node);
        const auto status = gpu_delegate_kernel->GetRequiredTemporaries(
            context, node, &node->temporaries);
        if (!status.ok()) {
          TF_LITE_KERNEL_LOG(context, "TfLiteGpuDelegate Prepare: %s",
                             std::string(status.message()).c_str());
          return kTfLiteError;
        }
        // TODO(akulik): tflite tensors are not allocated here either. It would
        // be good to set inputs and outputs only once here instead of setting
        // them every time in .invoke.
        return kTfLiteOk;
      },
      // .invoke
      [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
        const auto status = GetDelegateKernel(node)->Invoke(context);
        if (!status.ok()) {
          TF_LITE_KERNEL_LOG(context, "TfLiteGpuDelegate Invoke: %s",
                             std::string(status.message()).c_str());
          return kTfLiteError;
        }
        return kTfLiteOk;
      },
      nullptr,                // .profiling_string
      0,                      // .builtin_code
      "TfLiteGpuDelegateV2",  // .custom_name
      1,                      // .version
  };

  auto* gpu_delegate = GetDelegate(delegate);
  TfLiteIntArray* ops_to_replace =
      GetOpsToReplace(context, gpu_delegate->IsQuantOpsAllowed(),
                      gpu_delegate->MaxDelegatedPartitions());
  const auto status = context->ReplaceNodeSubsetsWithDelegateKernels(
      context, kRegistration, ops_to_replace, delegate);
  TFLITE_LOG_PROD(TFLITE_LOG_INFO, "Created %d GPU delegate kernels.",
                  gpu_delegate->num_delegate_kernels());
  TfLiteIntArrayFree(ops_to_replace);
  return status;
}

}  // namespace
}  // namespace gpu
}  // namespace tflite

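// A minimal usage sketch for application code (assumes a tflite::Interpreter
// instance named `interpreter`; not part of this file):
//
//   TfLiteGpuDelegateOptionsV2 options = TfLiteGpuDelegateOptionsV2Default();
//   TfLiteDelegate* delegate = TfLiteGpuDelegateV2Create(&options);
//   if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) { ... }
//   // ... run inference via interpreter->Invoke() ...
//   TfLiteGpuDelegateV2Delete(delegate);
//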
TfLiteGpuDelegateOptionsV2 TfLiteGpuDelegateOptionsV2Default() {
  TfLiteGpuDelegateOptionsV2 options;
  // set it to -1 to detect whether it was later adjusted.
  options.is_precision_loss_allowed = -1;
  options.inference_preference =
      TFLITE_GPU_INFERENCE_PREFERENCE_FAST_SINGLE_ANSWER;
  options.inference_priority1 = TFLITE_GPU_INFERENCE_PRIORITY_MAX_PRECISION;
  options.inference_priority2 = TFLITE_GPU_INFERENCE_PRIORITY_AUTO;
  options.inference_priority3 = TFLITE_GPU_INFERENCE_PRIORITY_AUTO;
  options.experimental_flags = TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_QUANT;
  options.max_delegated_partitions = 1;
  options.model_token = nullptr;
  options.serialization_dir = nullptr;
  return options;
}

TfLiteDelegate* TfLiteGpuDelegateV2Create(
    const TfLiteGpuDelegateOptionsV2* options) {
  auto* gpu_delegate = new tflite::gpu::Delegate(options);
  TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO,
                       "Created TensorFlow Lite delegate for GPU.");
  return gpu_delegate ? gpu_delegate->tflite_delegate() : nullptr;
}

void TfLiteGpuDelegateV2Delete(TfLiteDelegate* delegate) {
  delete tflite::gpu::GetDelegate(delegate);
}