/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file has utilities that facilitate creating new delegates.
// - SimpleDelegateKernelInterface: Represents a Kernel which handles a subgraph
//   to be delegated. It has Init/Prepare/Eval which are going to be called
//   during inference, similar to TFLite Kernels. Delegate owner should
//   implement this interface to build/prepare/invoke the delegated subgraph.
// - SimpleDelegateInterface:
//   This class wraps TfLiteDelegate and users need to implement the interface
//   and then call TfLiteDelegateFactory::CreateSimpleDelegate(...) to get a
//   TfLiteDelegate* that can be passed to ModifyGraphWithDelegate and freed via
//   TfLiteDelegateFactory::DeleteSimpleDelegate(...),
//   or call TfLiteDelegateFactory::Create(...) to get a std::unique_ptr
//   TfLiteDelegate that can also be passed to ModifyGraphWithDelegate, in which
//   case the TfLite interpreter takes ownership of the delegate's memory.
29 #ifndef TENSORFLOW_LITE_DELEGATES_UTILS_SIMPLE_DELEGATE_H_ 30 #define TENSORFLOW_LITE_DELEGATES_UTILS_SIMPLE_DELEGATE_H_ 31 32 #include <memory> 33 34 #include "tensorflow/lite/c/common.h" 35 36 namespace tflite { 37 38 using TfLiteDelegateUniquePtr = 39 std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>; 40 41 // Users should inherit from this class and implement the interface below. 42 // Each instance represents a single part of the graph (subgraph). 43 class SimpleDelegateKernelInterface { 44 public: ~SimpleDelegateKernelInterface()45 virtual ~SimpleDelegateKernelInterface() {} 46 47 // Initializes a delegated subgraph. 48 // The nodes in the subgraph are inside TfLiteDelegateParams->nodes_to_replace 49 virtual TfLiteStatus Init(TfLiteContext* context, 50 const TfLiteDelegateParams* params) = 0; 51 52 // Will be called by the framework. Should handle any needed preparation 53 // for the subgraph e.g. allocating buffers, compiling model. 54 // Returns status, and signalling any errors. 55 virtual TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) = 0; 56 57 // Actual subgraph inference should happen on this call. 58 // Returns status, and signalling any errors. 59 virtual TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) = 0; 60 }; 61 62 // Pure Interface that clients should implement. 63 // The Interface represents a delegate capabilities and provide factory 64 // for SimpleDelegateKernelInterface 65 // 66 // Clients should implement the following methods: 67 // - IsNodeSupportedByDelegate 68 // - Initialize 69 // - name 70 // - CreateDelegateKernelInterface 71 class SimpleDelegateInterface { 72 public: 73 // Options for configuring a delegate. 74 struct Options { 75 // Maximum number of delegated subgraph, values <=0 means unlimited. 76 int max_delegated_partitions = 0; 77 78 // The minimum number of nodes allowed in a delegated graph, values <=0 79 // means unlimited. 
80 int min_nodes_per_partition = 0; 81 }; 82 ~SimpleDelegateInterface()83 virtual ~SimpleDelegateInterface() {} 84 85 // Returns true if 'node' is supported by the delegate. False otherwise. 86 virtual bool IsNodeSupportedByDelegate(const TfLiteRegistration* registration, 87 const TfLiteNode* node, 88 TfLiteContext* context) const = 0; 89 90 // Initialize the delegate before finding and replacing TfLite nodes with 91 // delegate kernels, for example, retrieving some TFLite settings from 92 // 'context'. 93 virtual TfLiteStatus Initialize(TfLiteContext* context) = 0; 94 95 // Returns a name that identifies the delegate. 96 // This name is used for debugging/logging/profiling. 97 virtual const char* Name() const = 0; 98 99 // Returns instance of an object that implements the interface 100 // SimpleDelegateKernelInterface. 101 // An instance of SimpleDelegateKernelInterface represents one subgraph to 102 // be delegated. 103 // Caller takes ownership of the returned object. 104 virtual std::unique_ptr<SimpleDelegateKernelInterface> 105 CreateDelegateKernelInterface() = 0; 106 107 // Returns SimpleDelegateInterface::Options which has the delegate options. 108 virtual SimpleDelegateInterface::Options DelegateOptions() const = 0; 109 }; 110 111 // Factory class that provides static methods to deal with SimpleDelegate 112 // creation and deletion. 113 class TfLiteDelegateFactory { 114 public: 115 // Creates TfLiteDelegate from the provided SimpleDelegateInterface. 116 // The returned TfLiteDelegate should be deleted using DeleteSimpleDelegate. 
117 // A simple usage of the flags bit mask: 118 // CreateSimpleDelegate(..., kTfLiteDelegateFlagsAllowDynamicTensors | 119 // kTfLiteDelegateFlagsRequirePropagatedShapes) 120 static TfLiteDelegate* CreateSimpleDelegate( 121 std::unique_ptr<SimpleDelegateInterface> simple_delegate, 122 int64_t flags = kTfLiteDelegateFlagsNone); 123 124 // Deletes 'delegate' the passed pointer must be the one returned 125 // from CreateSimpleDelegate. 126 // This function will destruct the SimpleDelegate object too. 127 static void DeleteSimpleDelegate(TfLiteDelegate* delegate); 128 129 // A convenient function wrapping the above two functions and returning a 130 // std::unique_ptr type for auto memory management. Create(std::unique_ptr<SimpleDelegateInterface> simple_delegate)131 inline static TfLiteDelegateUniquePtr Create( 132 std::unique_ptr<SimpleDelegateInterface> simple_delegate) { 133 return TfLiteDelegateUniquePtr( 134 CreateSimpleDelegate(std::move(simple_delegate)), DeleteSimpleDelegate); 135 } 136 }; 137 138 } // namespace tflite 139 140 #endif // TENSORFLOW_LITE_DELEGATES_UTILS_SIMPLE_DELEGATE_H_ 141