• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "armnn/INetwork.hpp"
8 #include "armnnTfLiteParser/ITfLiteParser.hpp"
9 #include "armnn/Types.hpp"
10 
11 #include <schema_generated.h>
12 #include <functional>
13 #include <unordered_map>
14 #include <vector>
15 
16 namespace armnnTfLiteParser
17 {
18 
/// Implementation of the ITfLiteParser interface declared in this header.
/// Going by the method comments below, it loads a TfLite flatbuffers model (from a file or an
/// in-memory buffer) and builds an armnn::INetwork from it. Copying is disabled (see the deleted
/// copy constructor / copy assignment in the private section).
/// NOTE(review): this header uses std::unique_ptr, std::string, std::pair and fixed-width integer
/// types but only includes <functional>, <unordered_map> and <vector> directly; presumably the
/// remaining standard headers arrive transitively - confirm.
19 class TfLiteParser : public ITfLiteParser
20 {
21 public:
22     // Shorthands for TfLite types.
       // The *Ptr aliases are std::unique_ptr (owning); the *RawPtr aliases are raw const pointers
       // (presumably non-owning views into objects held by the model - confirm).
23     using ModelPtr = std::unique_ptr<tflite::ModelT>;
24     using SubgraphPtr = std::unique_ptr<tflite::SubGraphT>;
25     using OperatorPtr = std::unique_ptr<tflite::OperatorT>;
26     using OperatorCodePtr = std::unique_ptr<tflite::OperatorCodeT>;
27     using TensorPtr = std::unique_ptr<tflite::TensorT>;
28     using TensorRawPtr = const tflite::TensorT *;
29     using TensorRawPtrVector = std::vector<TensorRawPtr>;
       // Pair of a size_t and a tensor pointer (presumably the tensor's id within its subgraph - confirm)
30     using TensorIdRawPtr = std::pair<size_t, TensorRawPtr>;
31     using TensorIdRawPtrVector = std::vector<TensorIdRawPtr>;
32     using BufferPtr = std::unique_ptr<tflite::BufferT>;
33     using BufferRawPtr = const tflite::BufferT *;
34 
35 public:
36     /// Create the network from a flatbuffers binary file on disk
       /// @param graphFile Path of the file to load (C string)
37     virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile) override;
38 
39     /// Create the network from a flatbuffers binary
       /// @param binaryContent The serialized model bytes
40     virtual armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t> & binaryContent) override;
41 
42 
43     /// Retrieve binding info (layer id and tensor info) for the network input identified by
44     /// the given layer name and subgraph id
45     virtual BindingPointInfo GetNetworkInputBindingInfo(size_t subgraphId,
46                                                         const std::string& name) const override;
47 
48     /// Retrieve binding info (layer id and tensor info) for the network output identified by
49     /// the given layer name and subgraph id
50     virtual BindingPointInfo GetNetworkOutputBindingInfo(size_t subgraphId,
51                                                          const std::string& name) const override;
52 
53     /// Return the number of subgraphs in the parsed model
54     virtual size_t GetSubgraphCount() const override;
55 
56     /// Return the input tensor names for a given subgraph
57     virtual std::vector<std::string> GetSubgraphInputTensorNames(size_t subgraphId) const override;
58 
59     /// Return the output tensor names for a given subgraph
60     virtual std::vector<std::string> GetSubgraphOutputTensorNames(size_t subgraphId) const override;
61 
       /// Construct a parser. The options argument is optional and defaults to armnn::EmptyOptional().
62     TfLiteParser(const armnn::Optional<ITfLiteParser::TfLiteParserOptions>& options = armnn::EmptyOptional());
       // Virtual destructor with an empty body.
       // NOTE(review): the `~TfLiteParser()` text in front of the listing number on the next line
       // looks like a code-browser extraction artifact rather than source text - confirm against
       // the upstream file.
~TfLiteParser()63     virtual ~TfLiteParser() {}
64 
65 public:
66     // testable helpers
       // All of these are static and take the model (or other inputs) explicitly, so they can be
       // called without a TfLiteParser instance.
67     static ModelPtr LoadModelFromFile(const char * fileName);
68     static ModelPtr LoadModelFromBinary(const uint8_t * binaryContent, size_t len);
69     static TensorRawPtrVector GetInputs(const ModelPtr & model, size_t subgraphIndex, size_t operatorIndex);
70     static TensorRawPtrVector GetOutputs(const ModelPtr & model, size_t subgraphIndex, size_t operatorIndex);
71     static TensorIdRawPtrVector GetSubgraphInputs(const ModelPtr & model, size_t subgraphIndex);
72     static TensorIdRawPtrVector GetSubgraphOutputs(const ModelPtr & model, size_t subgraphIndex);
       // The two getters below return a non-const reference to a vector of int32_t
       // (presumably a vector stored inside the model, so the model must outlive it - confirm).
73     static std::vector<int32_t>& GetInputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
74     static std::vector<int32_t>& GetOutputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
75 
76     static BufferRawPtr GetBuffer(const ModelPtr& model, size_t bufferIndex);
77     static armnn::TensorInfo OutputShapeOfSqueeze(const std::vector<uint32_t> & squeezeDims,
78                                                   const armnn::TensorInfo & inputTensorInfo);
79     static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo & inputTensorInfo,
80                                                   const std::vector<int32_t> & targetDimsIn);
81 
82 private:
83     // No copying allowed until it is wanted and properly implemented
84     TfLiteParser(const TfLiteParser &) = delete;
85     TfLiteParser & operator=(const TfLiteParser &) = delete;
86 
87     /// Create the network from an already loaded flatbuffers model
       /// Takes no arguments (presumably reads the model held in m_Model - confirm).
88     armnn::INetworkPtr CreateNetworkFromModel();
89 
90     // signature for the parser functions:
       // a pointer to a TfLiteParser member function taking (subgraphIndex, operatorIndex) and returning void
91     using OperatorParsingFunction = void(TfLiteParser::*)(size_t subgraphIndex, size_t operatorIndex);
92 
93     void ParseCustomOperator(size_t subgraphIndex, size_t operatorIndex);
94     void ParseUnsupportedOperator(size_t subgraphIndex, size_t operatorIndex);
95 
        // Per-operator parse functions. Most match OperatorParsingFunction; ParseActivation, ParsePool and
        // ParseResize take an extra argument and therefore do not match that signature.
96     void ParseActivation(size_t subgraphIndex, size_t operatorIndex, armnn::ActivationFunction activationType);
97     void ParseAdd(size_t subgraphIndex, size_t operatorIndex);
98     void ParseAveragePool2D(size_t subgraphIndex, size_t operatorIndex);
99     void ParseBatchToSpaceND(size_t subgraphIndex, size_t operatorIndex);
100     void ParseConcatenation(size_t subgraphIndex, size_t operatorIndex);
101     void ParseConv2D(size_t subgraphIndex, size_t operatorIndex);
102     void ParseDepthwiseConv2D(size_t subgraphIndex, size_t operatorIndex);
103     void ParseDequantize(size_t subgraphIndex, size_t operatorIndex);
104     void ParseDetectionPostProcess(size_t subgraphIndex, size_t operatorIndex);
105     void ParseExp(size_t subgraphIndex, size_t operatorIndex);
106     void ParseFullyConnected(size_t subgraphIndex, size_t operatorIndex);
107     void ParseHardSwish(size_t subgraphIndex, size_t operatorIndex);
108     void ParseLeakyRelu(size_t subgraphIndex, size_t operatorIndex);
109     void ParseLogistic(size_t subgraphIndex, size_t operatorIndex);
110     void ParseL2Normalization(size_t subgraphIndex, size_t operatorIndex);
111     void ParseMaxPool2D(size_t subgraphIndex, size_t operatorIndex);
112     void ParseMaximum(size_t subgraphIndex, size_t operatorIndex);
113     void ParseMean(size_t subgraphIndex, size_t operatorIndex);
114     void ParseMinimum(size_t subgraphIndex, size_t operatorIndex);
115     void ParseMul(size_t subgraphIndex, size_t operatorIndex);
116     void ParseNeg(size_t subgraphIndex, size_t operatorIndex);
117     void ParsePack(size_t subgraphIndex, size_t operatorIndex);
118     void ParsePad(size_t subgraphIndex, size_t operatorIndex);
119     void ParsePool(size_t subgraphIndex, size_t operatorIndex, armnn::PoolingAlgorithm algorithm);
120     void ParseQuantize(size_t subgraphIndex, size_t operatorIndex);
121     void ParseRelu(size_t subgraphIndex, size_t operatorIndex);
122     void ParseRelu6(size_t subgraphIndex, size_t operatorIndex);
123     void ParseReshape(size_t subgraphIndex, size_t operatorIndex);
124     void ParseResize(size_t subgraphIndex, size_t operatorIndex, armnn::ResizeMethod resizeMethod);
125     void ParseResizeBilinear(size_t subgraphIndex, size_t operatorIndex);
126     void ParseResizeNearestNeighbor(size_t subgraphIndex, size_t operatorIndex);
127     void ParseSlice(size_t subgraphIndex, size_t operatorIndex);
128     void ParseSoftmax(size_t subgraphIndex, size_t operatorIndex);
129     void ParseSpaceToBatchND(size_t subgraphIndex, size_t operatorIndex);
130     void ParseSplit(size_t subgraphIndex, size_t operatorIndex);
131     void ParseSplitV(size_t subgraphIndex, size_t operatorIndex);
132     void ParseSqueeze(size_t subgraphIndex, size_t operatorIndex);
133     void ParseStridedSlice(size_t subgraphIndex, size_t operatorIndex);
134     void ParseSub(size_t subgraphIndex, size_t operatorIndex);
135     void ParseDiv(size_t subgraphIndex, size_t operatorIndex);
136     void ParseTanH(size_t subgraphIndex, size_t operatorIndex);
137     void ParseTranspose(size_t subgraphIndex, size_t operatorIndex);
138     void ParseTransposeConv(size_t subgraphIndex, size_t operatorIndex);
139     void ParseUnpack(size_t subgraphIndex, size_t operatorIndex);
140     void ParseArgMax(size_t subgraphIndex, size_t operatorIndex);
141 
        // Slot registration helpers: record a layer's output/input slot against a (subgraph, tensor) index pair
        // (presumably stored in m_SubgraphConnections - confirm against the implementation).
142     void RegisterProducerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IOutputSlot* slot);
143     void RegisterConsumerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IInputSlot* slot);
144     void RegisterInputSlots(size_t subgraphIndex,
145                             size_t operatorIndex,
146                             armnn::IConnectableLayer* layer,
147                             const std::vector<unsigned int>& tensorIndexes);
148     void RegisterOutputSlots(size_t subgraphIndex,
149                              size_t operatorIndex,
150                              armnn::IConnectableLayer* layer,
151                              const std::vector<unsigned int>& tensorIndexes);
152 
        // Per-subgraph setup of input, output and constant layers
153     void SetupInputLayers(size_t subgraphIndex);
154     void SetupOutputLayers(size_t subgraphIndex);
155     void SetupConstantLayers(size_t subgraphIndex);
156 
157     void ResetParser();
158 
159     void AddBroadcastReshapeLayer(size_t subgraphIndex,
160                                   size_t operatorIndex,
161                                   armnn::IConnectableLayer* layer);
162 
163     /// Attach an activation layer to the one passed as a parameter
        /// @param layer          The layer to attach the activation to
        /// @param outputSlot     Index of an output slot (presumably of `layer` - confirm)
        /// @param activationType The TfLite fused activation function type
        /// @return A layer pointer (presumably the newly added activation layer - confirm)
164     armnn::IConnectableLayer* AddFusedActivationLayer(armnn::IConnectableLayer* layer,
165                                                       unsigned int outputSlot,
166                                                       tflite::ActivationFunctionType activationType);
167 
168     // SupportedDataStorage's purpose is to hold data till we pass over to the network.
169     // We don't care about the content, and we want a single datatype to simplify the code.
170     struct SupportedDataStorage
171     {
172     public:
173         // Convenience constructors, one per supported element type; each takes the buffer by rvalue reference
174         SupportedDataStorage(std::unique_ptr<float[]>&&   data);
175         SupportedDataStorage(std::unique_ptr<uint8_t[]>&& data);
176         SupportedDataStorage(std::unique_ptr<int8_t[]>&&  data);
177         SupportedDataStorage(std::unique_ptr<int32_t[]>&& data);
178 
179     private:
180         // Pointers to the data buffers
            // (presumably only the one matching the constructor used is non-null - confirm)
181         std::unique_ptr<float[]>   m_FloatData;
182         std::unique_ptr<uint8_t[]> m_Uint8Data;
183         std::unique_ptr<int8_t[]>  m_Int8Data;
184         std::unique_ptr<int32_t[]> m_Int32Data;
185     };
186 
187 
        // Returns a ConstTensor together with the SupportedDataStorage paired with it; T is the element type.
        // tensorInfo is taken by non-const reference, and the permutation vector is optional.
188     template<typename T>
189     std::pair<armnn::ConstTensor, TfLiteParser::SupportedDataStorage>
190     CreateConstTensorAndStoreData(TfLiteParser::BufferRawPtr bufferPtr,
191                                   TfLiteParser::TensorRawPtr tensorPtr,
192                                   armnn::TensorInfo& tensorInfo,
193                                   armnn::Optional<armnn::PermutationVector&> permutationVector);
194 
        // Non-template counterpart with the same return type
        // (presumably dispatches on the tensor's data type to the template above - confirm).
195     std::pair<armnn::ConstTensor, SupportedDataStorage>
196     CreateConstTensor(TensorRawPtr tensorPtr,
197                       armnn::TensorInfo& tensorInfo,
198                       armnn::Optional<armnn::PermutationVector&> permutationVector);
199 
200     // Settings for configuring the TfLiteParser
201     armnn::Optional<ITfLiteParser::TfLiteParserOptions> m_Options;
202 
203     /// The network we're building. Gets cleared after it is passed to the user
204     armnn::INetworkPtr                    m_Network;
        /// The loaded model, held through an owning ModelPtr
205     ModelPtr                              m_Model;
206 
        /// Parse functions held in a vector (presumably indexed by builtin operator code - confirm)
207     std::vector<OperatorParsingFunction>                     m_ParserFunctions;
        /// Parse functions keyed by a string (presumably the custom operator's name - confirm)
208     std::unordered_map<std::string, OperatorParsingFunction> m_CustomParserFunctions;
209 
210     /// A mapping of an output slot to each of the input slots it should be connected to
211     /// The outputSlot is from the layer that creates this tensor as one of its outputs
212     /// The inputSlots are from the layers that use this tensor as one of their inputs
213     struct TensorSlots
214     {
215         armnn::IOutputSlot* outputSlot;
216         std::vector<armnn::IInputSlot*> inputSlots;
217 
            // Default constructor: outputSlot starts as nullptr; inputSlots is default-constructed (empty).
            // NOTE(review): the `TensorSlotsarmnnTfLiteParser::TfLiteParser::TensorSlots` text in front of
            // the listing number on the next line looks like a code-browser extraction artifact rather
            // than source text - confirm against the upstream file.
TensorSlotsarmnnTfLiteParser::TfLiteParser::TensorSlots218         TensorSlots() : outputSlot(nullptr) { }
219     };
220     typedef std::vector<TensorSlots> TensorConnections;
221     /// Connections for tensors in each subgraph
222     /// The first index is the subgraph ID, the second index is the tensor ID
223     std::vector<TensorConnections> m_SubgraphConnections;
224 
225     /// This is used in case the model does not specify the output.
226     /// The shape can be calculated from the options.
227     std::vector<std::vector<unsigned int>> m_OverridenOutputShapes;
228 };
229 
230 }
231