//
// Copyright © 2017,2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "armnnOnnxParser/IOnnxParser.hpp"
#include "google/protobuf/repeated_field.h"
#include <unordered_map>

#include <onnx/onnx.pb.h>


namespace armnn
{
class TensorInfo;
enum class ActivationFunction;
}

namespace armnnOnnxParser
{

using ModelPtr = std::unique_ptr<onnx::ModelProto>;

class OnnxParserImpl
{

using OperationParsingFunction = void(OnnxParserImpl::*)(const onnx::NodeProto& NodeProto);

public:

    using GraphPtr = std::unique_ptr<onnx::GraphProto>;

    /// Create the network from a protobuf binary file on disk
    armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile);

    /// Create the network from a protobuf binary file on disk, with inputShapes specified
    armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile,
                                                   const std::map<std::string, armnn::TensorShape>& inputShapes);

    /// Create the network from a protobuf binary
    armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent);

    /// Create the network from a protobuf binary, with inputShapes specified
    armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent,
                                               const std::map<std::string, armnn::TensorShape>& inputShapes);

    /// Create the network from a protobuf text file on disk
    armnn::INetworkPtr CreateNetworkFromTextFile(const char* graphFile);

    /// Create the network from a protobuf text file on disk, with inputShapes specified
    armnn::INetworkPtr CreateNetworkFromTextFile(const char* graphFile,
                                                 const std::map<std::string, armnn::TensorShape>& inputShapes);

    /// Create the network directly from protobuf text in a string. Useful for debugging/testing
    armnn::INetworkPtr CreateNetworkFromString(const std::string& protoText);

    /// Create the network directly from protobuf text in a string, with inputShapes specified.
    /// Useful for debugging/testing
    armnn::INetworkPtr CreateNetworkFromString(const std::string& protoText,
                                               const std::map<std::string, armnn::TensorShape>& inputShapes);

    /// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name
    BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const;

    /// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name
    BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const;
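    /// Typical use of the interface above, as a sketch only: the file path and tensor names
    /// ("model.onnx", "input", "output") are placeholders, and client code normally reaches this
    /// implementation through the public IOnnxParser interface declared in IOnnxParser.hpp.
    /// @code
    /// OnnxParserImpl parser;
    /// armnn::INetworkPtr network = parser.CreateNetworkFromBinaryFile("model.onnx");
    /// BindingPointInfo inputBinding  = parser.GetNetworkInputBindingInfo("input");
    /// BindingPointInfo outputBinding = parser.GetNetworkOutputBindingInfo("output");
    /// @endcode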
public:

    OnnxParserImpl();
    ~OnnxParserImpl() = default;

    static ModelPtr LoadModelFromBinary(const std::vector<uint8_t>& binaryContent);
    static ModelPtr LoadModelFromBinaryFile(const char* fileName);
    static ModelPtr LoadModelFromTextFile(const char* fileName);
    static ModelPtr LoadModelFromString(const std::string& inputString);

    /// Retrieve input names
    static std::vector<std::string> GetInputs(ModelPtr& model);

    /// Retrieve output names
    static std::vector<std::string> GetOutputs(ModelPtr& model);

    /// Retrieve version in X.Y.Z form
    static const std::string GetVersion();
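    /// The static helpers above allow a model to be inspected without building a network.
    /// A sketch only; the file path is a placeholder:
    /// @code
    /// ModelPtr model = OnnxParserImpl::LoadModelFromBinaryFile("model.onnx");
    /// std::vector<std::string> inputNames  = OnnxParserImpl::GetInputs(model);
    /// std::vector<std::string> outputNames = OnnxParserImpl::GetOutputs(model);
    /// @endcode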
private:

    /// Parses a ModelProto loaded into memory from one of the other CreateNetwork*
    armnn::INetworkPtr CreateNetworkFromModel(onnx::ModelProto& model);

    /// Parse every node and make the connection between the resulting tensors
    void LoadGraph();

    void SetupInfo(const google::protobuf::RepeatedPtrField<onnx::ValueInfoProto>* list);

    std::vector<armnn::TensorInfo> ComputeOutputInfo(
        std::vector<std::string> outNames,
        const armnn::IConnectableLayer* layer,
        std::vector<armnn::TensorShape> inputShapes,
        const onnx::TensorProto::DataType& type = onnx::TensorProto::FLOAT);

    void DetectFullyConnected();

    template <typename Location>
    void GetInputAndParam(const onnx::NodeProto& node,
                          std::string* inputName,
                          std::string* constName,
                          const Location& location);

    template <typename Location>
    void To1DTensor(const std::string& name, const Location& location);

    /// Broadcast preparation functions
    std::pair<std::string, std::string> AddPrepareBroadcast(const std::string& input0, const std::string& input1);
    void PrependForBroadcast(const std::string& outputName, const std::string& input0, const std::string& input1);

    void AddConvLayerWithDepthwiseConv(const onnx::NodeProto& node, const armnn::Convolution2dDescriptor& convDesc);
    void AddFullyConnected(const onnx::NodeProto& matmulNode, const onnx::NodeProto* addNode = nullptr);
    void AddPoolingLayer(const onnx::NodeProto& nodeProto, armnn::Pooling2dDescriptor& desc);

    void CreateConstantLayer(const std::string& tensorName, const std::string& layerName);
    void CreateInt64ConstantLayer(const std::string& tensorName, const std::string& layerName);
    void CreateReshapeLayer(const std::string& inputName,
                            const std::string& outputName,
                            const std::string& layerName);

    void ParseActivation(const onnx::NodeProto& nodeProto, const armnn::ActivationFunction func);
    void ParseClip(const onnx::NodeProto& nodeProto);
    void ParseSigmoid(const onnx::NodeProto& nodeProto);
    void ParseTanh(const onnx::NodeProto& nodeProto);
    void ParseRelu(const onnx::NodeProto& nodeProto);
    void ParseLeakyRelu(const onnx::NodeProto& nodeProto);

    void ParseAdd(const onnx::NodeProto& nodeProto);
    void ParseAveragePool(const onnx::NodeProto& nodeProto);
    void ParseBatchNormalization(const onnx::NodeProto& node);
    void ParseConcat(const onnx::NodeProto& nodeProto);
    void ParseConstant(const onnx::NodeProto& nodeProto);
    void ParseConv(const onnx::NodeProto& nodeProto);
    void ParseFlatten(const onnx::NodeProto& node);
    void ParseGather(const onnx::NodeProto& node);
    void ParseGemm(const onnx::NodeProto& node);
    void ParseGlobalAveragePool(const onnx::NodeProto& node);
    void ParseMaxPool(const onnx::NodeProto& nodeProto);
    void ParseShape(const onnx::NodeProto& node);
    void ParseReshape(const onnx::NodeProto& nodeProto);
    void ParseUnsqueeze(const onnx::NodeProto& nodeProto);

    void RegisterInputSlot(armnn::IConnectableLayer* layer,
                           const std::string& tensorId,
                           unsigned int slotIndex);
    void RegisterInputSlots(armnn::IConnectableLayer* layer, const std::vector<std::string>& tensorIndexes);
    void RegisterOutputSlots(armnn::IConnectableLayer* layer, const std::vector<std::string>& tensorIndexes);

    void SetupInputLayers();
    void SetupOutputLayers();

    void ResetParser();
    void Cleanup();

    std::pair<armnn::ConstTensor, std::unique_ptr<float[]>>
    CreateConstTensor(const std::string name,
                      armnn::Optional<armnn::PermutationVector&> permutationVector = armnn::EmptyOptional());

    std::pair<armnn::ConstTensor, std::unique_ptr<int32_t[]>>
    CreateInt64ConstTensor(const std::string name,
                           armnn::Optional<armnn::PermutationVector&> permutationVector = armnn::EmptyOptional());

    template <typename TypeList, typename Location>
    void ValidateInputs(const onnx::NodeProto& node,
                        TypeList validInputs,
                        const Location& location);

    /// The network we're building. Gets cleared after it is passed to the user
    armnn::INetworkPtr m_Network;

    /// Ptr to the graph we're building the network from
    GraphPtr m_Graph;

    /// Map of the information for every tensor
    struct OnnxTensor
    {
        std::unique_ptr<armnn::TensorInfo> m_info;
        std::unique_ptr<const onnx::TensorProto> m_tensor;
        onnx::TensorProto::DataType m_dtype;

        OnnxTensor() : m_info(nullptr), m_tensor(nullptr), m_dtype(onnx::TensorProto::FLOAT) { }
        bool isConstant() { return m_tensor != nullptr; }
    };

    std::unordered_map<std::string, OnnxTensor> m_TensorsInfo;

    /// Map of ONNX operation names to parsing member functions
    static const std::map<std::string, OperationParsingFunction> m_ParserFunctions;
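    /// A sketch of how a table like m_ParserFunctions is typically consulted while walking the
    /// graph nodes; the actual lookup lives in the implementation file and may differ in detail:
    /// @code
    /// auto it = m_ParserFunctions.find(node.op_type());
    /// if (it != m_ParserFunctions.end())
    /// {
    ///     (this->*(it->second))(node);
    /// }
    /// @endcode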
    /// A mapping of an output slot to each of the input slots it should be connected to
    /// The outputSlot is from the layer that creates this tensor as one of its outputs
    /// The inputSlots are from the layers that use this tensor as one of their inputs
    struct TensorSlots
    {
        armnn::IOutputSlot* outputSlot;
        std::vector<armnn::IInputSlot*> inputSlots;

        TensorSlots() : outputSlot(nullptr) { }
    };
    /// Map of the tensor names to their connections, used to wire up the layers of the graph
    std::unordered_map<std::string, TensorSlots> m_TensorConnections;

    /// Map of the tensor names to their node and index in graph.node()
    std::unordered_map<std::string, std::pair<const onnx::NodeProto*, int>> m_OutputsMap;

    /// Number of times a specific node (identified by its index number) was used as input,
    /// and the list of the nodes it was fused with
    struct UsageSummary
    {
        std::vector<size_t> fusedWithNodes;
        size_t inputForNodes;

        UsageSummary() : fusedWithNodes({}), inputForNodes(0) { }
    };

    std::vector<UsageSummary> m_OutputsFusedAndUsed;

    std::map<std::string, armnn::TensorShape> m_InputShapes;

    std::unordered_map<std::string, armnn::TensorInfo> m_InputInfos;

    std::unordered_map<std::string, armnn::TensorInfo> m_OutputInfos;

};
}