• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include "DelegateOptions.hpp"
9 
10 #include <tensorflow/lite/builtin_ops.h>
11 #include <tensorflow/lite/c/builtin_op_data.h>
12 #include <tensorflow/lite/c/common.h>
13 #include <tensorflow/lite/minimal_logging.h>
14 
15 namespace armnnDelegate
16 {
17 
18 struct DelegateData
19 {
DelegateDataarmnnDelegate::DelegateData20     DelegateData(const std::vector<armnn::BackendId>& backends)
21         : m_Backends(backends)
22         , m_Network(nullptr, nullptr)
23     {}
24 
25     const std::vector<armnn::BackendId>       m_Backends;
26     armnn::INetworkPtr                        m_Network;
27     std::vector<armnn::IOutputSlot*>          m_OutputSlotForNode;
28 };
29 
30 // Forward decleration for functions initializing the ArmNN Delegate
31 DelegateOptions TfLiteArmnnDelegateOptionsDefault();
32 
33 TfLiteDelegate* TfLiteArmnnDelegateCreate(armnnDelegate::DelegateOptions options);
34 
35 void TfLiteArmnnDelegateDelete(TfLiteDelegate* tfLiteDelegate);
36 
37 TfLiteStatus DoPrepare(TfLiteContext* context, TfLiteDelegate* delegate);
38 
39 /// ArmNN Delegate
40 class Delegate
41 {
42     friend class ArmnnSubgraph;
43 public:
44     explicit Delegate(armnnDelegate::DelegateOptions options);
45 
46     TfLiteIntArray* IdentifyOperatorsToDelegate(TfLiteContext* context);
47 
48     TfLiteDelegate* GetDelegate();
49 
50 private:
51     TfLiteDelegate m_Delegate = {
52         reinterpret_cast<void*>(this),  // .data_
53         DoPrepare,                      // .Prepare
54         nullptr,                        // .CopyFromBufferHandle
55         nullptr,                        // .CopyToBufferHandle
56         nullptr,                        // .FreeBufferHandle
57         kTfLiteDelegateFlagsNone,       // .flags
58     };
59 
60     /// ArmNN Runtime pointer
61     armnn::IRuntimePtr m_Runtime;
62     /// ArmNN Delegate Options
63     armnnDelegate::DelegateOptions m_Options;
64 };
65 
66 /// ArmnnSubgraph class where parsing the nodes to ArmNN format and creating the ArmNN Graph
67 class ArmnnSubgraph
68 {
69 public:
70     static ArmnnSubgraph* Create(TfLiteContext* tfLiteContext,
71                                  const TfLiteDelegateParams* parameters,
72                                  const Delegate* delegate);
73 
74     TfLiteStatus Prepare(TfLiteContext* tfLiteContext);
75 
76     TfLiteStatus Invoke(TfLiteContext* tfLiteContext, TfLiteNode* tfLiteNode);
77 
78     static TfLiteStatus VisitNode(DelegateData& delegateData,
79                                   TfLiteContext* tfLiteContext,
80                                   TfLiteRegistration* tfLiteRegistration,
81                                   TfLiteNode* tfLiteNode,
82                                   int nodeIndex);
83 
84 private:
ArmnnSubgraph(armnn::NetworkId networkId,armnn::IRuntime * runtime,std::vector<armnn::BindingPointInfo> & inputBindings,std::vector<armnn::BindingPointInfo> & outputBindings)85     ArmnnSubgraph(armnn::NetworkId networkId,
86                   armnn::IRuntime* runtime,
87                   std::vector<armnn::BindingPointInfo>& inputBindings,
88                   std::vector<armnn::BindingPointInfo>& outputBindings)
89         : m_NetworkId(networkId), m_Runtime(runtime), m_InputBindings(inputBindings), m_OutputBindings(outputBindings)
90     {}
91 
92     static TfLiteStatus AddInputLayer(DelegateData& delegateData,
93                                       TfLiteContext* tfLiteContext,
94                                       const TfLiteIntArray* inputs,
95                                       std::vector<armnn::BindingPointInfo>& inputBindings);
96 
97     static TfLiteStatus AddOutputLayer(DelegateData& delegateData,
98                                        TfLiteContext* tfLiteContext,
99                                        const TfLiteIntArray* outputs,
100                                        std::vector<armnn::BindingPointInfo>& outputBindings);
101 
102 
103     /// The Network Id
104     armnn::NetworkId m_NetworkId;
105     /// ArmNN Rumtime
106     armnn::IRuntime* m_Runtime;
107 
108     // Binding information for inputs and outputs
109     std::vector<armnn::BindingPointInfo> m_InputBindings;
110     std::vector<armnn::BindingPointInfo> m_OutputBindings;
111 
112 };
113 
114 } // armnnDelegate namespace
115 
116 
117