//
// Copyright © 2017,2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "armnnOnnxParser/IOnnxParser.hpp"
#include "google/protobuf/repeated_field.h"
#include <unordered_map>

#include <onnx/onnx.pb.h>


namespace armnn
{
class TensorInfo;
enum class ActivationFunction;
}

namespace armnnOnnxParser
{

using ModelPtr = std::unique_ptr<onnx::ModelProto>;

class OnnxParserImpl
{

using OperationParsingFunction = void(OnnxParserImpl::*)(const onnx::NodeProto& NodeProto);

public:

    using GraphPtr = std::unique_ptr<onnx::GraphProto>;

    /// Create the network from a protobuf binary file on disk
    armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile);

    /// Create the network from a protobuf binary file on disk, with inputShapes specified
    armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile,
                                                   const std::map<std::string, armnn::TensorShape>& inputShapes);

    /// Create the network from a protobuf binary
    armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent);

    /// Create the network from a protobuf binary, with inputShapes specified
    armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent,
                                               const std::map<std::string, armnn::TensorShape>& inputShapes);

    /// Create the network from a protobuf text file on disk
    armnn::INetworkPtr CreateNetworkFromTextFile(const char* graphFile);

    /// Create the network from a protobuf text file on disk, with inputShapes specified
    armnn::INetworkPtr CreateNetworkFromTextFile(const char* graphFile,
                                                 const std::map<std::string, armnn::TensorShape>& inputShapes);

    /// Create the network directly from protobuf text in a string. Useful for debugging/testing
    armnn::INetworkPtr CreateNetworkFromString(const std::string& protoText);

    /// Create the network directly from protobuf text in a string, with inputShapes specified.
    /// Useful for debugging/testing
    armnn::INetworkPtr CreateNetworkFromString(const std::string& protoText,
                                               const std::map<std::string, armnn::TensorShape>& inputShapes);

    /// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name
    BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const;

    /// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name
    BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const;

public:

    OnnxParserImpl();
    ~OnnxParserImpl() = default;

    static ModelPtr LoadModelFromBinary(const std::vector<uint8_t>& binaryContent);
    static ModelPtr LoadModelFromBinaryFile(const char* fileName);
    static ModelPtr LoadModelFromTextFile(const char* fileName);
    static ModelPtr LoadModelFromString(const std::string& inputString);

    /// Retrieve input names
    static std::vector<std::string> GetInputs(ModelPtr& model);

    /// Retrieve output names
    static std::vector<std::string> GetOutputs(ModelPtr& model);

    /// Retrieve version in X.Y.Z form
    static const std::string GetVersion();

private:

    /// Parses a ModelProto that was loaded into memory by one of the CreateNetwork* functions above
    armnn::INetworkPtr CreateNetworkFromModel(onnx::ModelProto& model);

    /// Parse every node and make the connection between the resulting tensors
    void LoadGraph();

    void SetupInfo(const google::protobuf::RepeatedPtrField<onnx::ValueInfoProto>* list);

    std::vector<armnn::TensorInfo> ComputeOutputInfo(
        std::vector<std::string> outNames,
        const armnn::IConnectableLayer* layer,
        std::vector<armnn::TensorShape> inputShapes,
        const onnx::TensorProto::DataType& type = onnx::TensorProto::FLOAT);

    void DetectFullyConnected();

    template <typename Location>
    void GetInputAndParam(const onnx::NodeProto& node,
                          std::string* inputName,
                          std::string* constName,
                          const Location& location);

    template <typename Location>
    void To1DTensor(const std::string& name, const Location& location);

    // Broadcast preparation functions
    std::pair<std::string, std::string> AddPrepareBroadcast(const std::string& input0, const std::string& input1);
    void PrependForBroadcast(const std::string& outputName, const std::string& input0, const std::string& input1);

    void AddConvLayerWithDepthwiseConv(const onnx::NodeProto& node, const armnn::Convolution2dDescriptor& convDesc);
    void AddFullyConnected(const onnx::NodeProto& matmulNode, const onnx::NodeProto* addNode = nullptr);
    void AddPoolingLayer(const onnx::NodeProto& nodeProto, armnn::Pooling2dDescriptor& desc);

    void CreateConstantLayer(const std::string& tensorName, const std::string& layerName);
    void CreateInt64ConstantLayer(const std::string& tensorName, const std::string& layerName);
    void CreateReshapeLayer(const std::string& inputName,
                            const std::string& outputName,
                            const std::string& layerName);

    void ParseActivation(const onnx::NodeProto& nodeProto, const armnn::ActivationFunction func);
    void ParseClip(const onnx::NodeProto& nodeProto);
    void ParseSigmoid(const onnx::NodeProto& nodeProto);
    void ParseTanh(const onnx::NodeProto& nodeProto);
    void ParseRelu(const onnx::NodeProto& nodeProto);
    void ParseLeakyRelu(const onnx::NodeProto& nodeProto);

    void ParseAdd(const onnx::NodeProto& nodeProto);
    void ParseAveragePool(const onnx::NodeProto& nodeProto);
    void ParseBatchNormalization(const onnx::NodeProto& node);
    void ParseConcat(const onnx::NodeProto& nodeProto);
    void ParseConstant(const onnx::NodeProto& nodeProto);
    void ParseConv(const onnx::NodeProto& nodeProto);
    void ParseFlatten(const onnx::NodeProto& node);
    void ParseGather(const onnx::NodeProto& node);
    void ParseGemm(const onnx::NodeProto& node);
    void ParseGlobalAveragePool(const onnx::NodeProto& node);
    void ParseMaxPool(const onnx::NodeProto& nodeProto);
    void ParseShape(const onnx::NodeProto& node);
    void ParseReshape(const onnx::NodeProto& nodeProto);
    void ParseUnsqueeze(const onnx::NodeProto& nodeProto);

    void RegisterInputSlot(armnn::IConnectableLayer* layer,
                           const std::string& tensorId,
                           unsigned int slotIndex);
    void RegisterInputSlots(armnn::IConnectableLayer* layer, const std::vector<std::string>& tensorIndexes);
    void RegisterOutputSlots(armnn::IConnectableLayer* layer, const std::vector<std::string>& tensorIndexes);

    void SetupInputLayers();
    void SetupOutputLayers();

    void ResetParser();
    void Cleanup();

    std::pair<armnn::ConstTensor, std::unique_ptr<float[]>>
    CreateConstTensor(const std::string name,
                      armnn::Optional<armnn::PermutationVector&> permutationVector = armnn::EmptyOptional());

    std::pair<armnn::ConstTensor, std::unique_ptr<int32_t[]>>
    CreateInt64ConstTensor(const std::string name,
                           armnn::Optional<armnn::PermutationVector&> permutationVector = armnn::EmptyOptional());

    template <typename TypeList, typename Location>
    void ValidateInputs(const onnx::NodeProto& node,
                        TypeList validInputs,
                        const Location& location);

    /// The network we're building. Gets cleared after it is passed to the user
    armnn::INetworkPtr m_Network;

    /// Ptr to the graph we're building the network from
    GraphPtr m_Graph;

    /// Map of the information for every tensor
    struct OnnxTensor
    {
        std::unique_ptr<armnn::TensorInfo>          m_info;
        std::unique_ptr<const onnx::TensorProto>    m_tensor;
        onnx::TensorProto::DataType                 m_dtype;

        OnnxTensor() : m_info(nullptr), m_tensor(nullptr), m_dtype(onnx::TensorProto::FLOAT) { }
        bool isConstant() { return m_tensor != nullptr; }
    };

    std::unordered_map<std::string, OnnxTensor> m_TensorsInfo;

    /// Map of ONNX operation names to parsing member functions
    static const std::map<std::string, OperationParsingFunction> m_ParserFunctions;

    /// A mapping of an output slot to each of the input slots it should be connected to
    /// The outputSlot is from the layer that creates this tensor as one of its outputs
    /// The inputSlots are from the layers that use this tensor as one of their inputs
    struct TensorSlots
    {
        armnn::IOutputSlot* outputSlot;
        std::vector<armnn::IInputSlot*> inputSlots;

        TensorSlots() : outputSlot(nullptr) { }
    };
    /// Map of tensor names to their connections, used to connect up the layers of the graph
    std::unordered_map<std::string, TensorSlots> m_TensorConnections;

    /// Map of the tensor names to their node and index in graph.node()
    std::unordered_map<std::string, std::pair<const onnx::NodeProto*, int>> m_OutputsMap;

    /// Number of times a specific node (identified by its index number) was used as input
    /// and list of the nodes it was fused with
    struct UsageSummary
    {
        std::vector<size_t> fusedWithNodes;
        size_t inputForNodes;

        UsageSummary() : fusedWithNodes({}), inputForNodes(0) { }

    };

    std::vector<UsageSummary> m_OutputsFusedAndUsed;

    std::map<std::string, armnn::TensorShape> m_InputShapes;

    std::unordered_map<std::string, armnn::TensorInfo> m_InputInfos;

    std::unordered_map<std::string, armnn::TensorInfo> m_OutputInfos;

};
}
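
// A minimal usage sketch, shown for context rather than as part of this header: it goes
// through the public IOnnxParser interface from armnnOnnxParser/IOnnxParser.hpp, which
// OnnxParserImpl backs. The function name LoadOnnxModelSketch, the file name "model.onnx"
// and the tensor names "input"/"output" below are illustrative assumptions only.

#include <armnnOnnxParser/IOnnxParser.hpp>

inline armnn::INetworkPtr LoadOnnxModelSketch()
{
    // Create a parser and build an armnn::INetwork from an ONNX protobuf binary file on disk.
    armnnOnnxParser::IOnnxParserPtr parser = armnnOnnxParser::IOnnxParser::Create();
    armnn::INetworkPtr network = parser->CreateNetworkFromBinaryFile("model.onnx");

    // Retrieve binding info (layer id and armnn::TensorInfo) for named graph tensors,
    // mirroring the GetNetworkInputBindingInfo/GetNetworkOutputBindingInfo calls above.
    armnnOnnxParser::BindingPointInfo inputBinding  = parser->GetNetworkInputBindingInfo("input");
    armnnOnnxParser::BindingPointInfo outputBinding = parser->GetNetworkOutputBindingInfo("output");
    (void) inputBinding;   // unused in this sketch; a caller would pass these to the runtime
    (void) outputBinding;

    return network;
}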