1 // 2 // Copyright © 2020 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include "DelegateOptions.hpp" 9 10 #include <tensorflow/lite/builtin_ops.h> 11 #include <tensorflow/lite/c/builtin_op_data.h> 12 #include <tensorflow/lite/c/common.h> 13 #include <tensorflow/lite/minimal_logging.h> 14 15 namespace armnnDelegate 16 { 17 18 struct DelegateData 19 { DelegateDataarmnnDelegate::DelegateData20 DelegateData(const std::vector<armnn::BackendId>& backends) 21 : m_Backends(backends) 22 , m_Network(nullptr, nullptr) 23 {} 24 25 const std::vector<armnn::BackendId> m_Backends; 26 armnn::INetworkPtr m_Network; 27 std::vector<armnn::IOutputSlot*> m_OutputSlotForNode; 28 }; 29 30 // Forward decleration for functions initializing the ArmNN Delegate 31 DelegateOptions TfLiteArmnnDelegateOptionsDefault(); 32 33 TfLiteDelegate* TfLiteArmnnDelegateCreate(armnnDelegate::DelegateOptions options); 34 35 void TfLiteArmnnDelegateDelete(TfLiteDelegate* tfLiteDelegate); 36 37 TfLiteStatus DoPrepare(TfLiteContext* context, TfLiteDelegate* delegate); 38 39 /// ArmNN Delegate 40 class Delegate 41 { 42 friend class ArmnnSubgraph; 43 public: 44 explicit Delegate(armnnDelegate::DelegateOptions options); 45 46 TfLiteIntArray* IdentifyOperatorsToDelegate(TfLiteContext* context); 47 48 TfLiteDelegate* GetDelegate(); 49 50 private: 51 TfLiteDelegate m_Delegate = { 52 reinterpret_cast<void*>(this), // .data_ 53 DoPrepare, // .Prepare 54 nullptr, // .CopyFromBufferHandle 55 nullptr, // .CopyToBufferHandle 56 nullptr, // .FreeBufferHandle 57 kTfLiteDelegateFlagsNone, // .flags 58 }; 59 60 /// ArmNN Runtime pointer 61 armnn::IRuntimePtr m_Runtime; 62 /// ArmNN Delegate Options 63 armnnDelegate::DelegateOptions m_Options; 64 }; 65 66 /// ArmnnSubgraph class where parsing the nodes to ArmNN format and creating the ArmNN Graph 67 class ArmnnSubgraph 68 { 69 public: 70 static ArmnnSubgraph* Create(TfLiteContext* tfLiteContext, 71 const TfLiteDelegateParams* parameters, 72 const Delegate* delegate); 73 74 TfLiteStatus Prepare(TfLiteContext* tfLiteContext); 75 76 TfLiteStatus Invoke(TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode); 77 78 static TfLiteStatus VisitNode(DelegateData& delegateData, 79 TfLiteContext* tfLiteContext, 80 TfLiteRegistration* tfLiteRegistration, 81 TfLiteNode* tfLiteNode, 82 int nodeIndex); 83 84 private: ArmnnSubgraph(armnn::NetworkId networkId,armnn::IRuntime * runtime,std::vector<armnn::BindingPointInfo> & inputBindings,std::vector<armnn::BindingPointInfo> & outputBindings)85 ArmnnSubgraph(armnn::NetworkId networkId, 86 armnn::IRuntime* runtime, 87 std::vector<armnn::BindingPointInfo>& inputBindings, 88 std::vector<armnn::BindingPointInfo>& outputBindings) 89 : m_NetworkId(networkId), m_Runtime(runtime), m_InputBindings(inputBindings), m_OutputBindings(outputBindings) 90 {} 91 92 static TfLiteStatus AddInputLayer(DelegateData& delegateData, 93 TfLiteContext* tfLiteContext, 94 const TfLiteIntArray* inputs, 95 std::vector<armnn::BindingPointInfo>& inputBindings); 96 97 static TfLiteStatus AddOutputLayer(DelegateData& delegateData, 98 TfLiteContext* tfLiteContext, 99 const TfLiteIntArray* outputs, 100 std::vector<armnn::BindingPointInfo>& outputBindings); 101 102 103 /// The Network Id 104 armnn::NetworkId m_NetworkId; 105 /// ArmNN Rumtime 106 armnn::IRuntime* m_Runtime; 107 108 // Binding information for inputs and outputs 109 std::vector<armnn::BindingPointInfo> m_InputBindings; 110 std::vector<armnn::BindingPointInfo> m_OutputBindings; 111 112 }; 113 114 } // armnnDelegate namespace 115 116 117