1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include "armnnOnnxParser/IOnnxParser.hpp" 8 #include "google/protobuf/repeated_field.h" 9 #include <unordered_map> 10 11 #include <onnx/onnx.pb.h> 12 13 14 namespace armnn 15 { 16 class TensorInfo; 17 enum class ActivationFunction; 18 } 19 20 namespace armnnOnnxParser 21 { 22 23 using ModelPtr = std::unique_ptr<onnx::ModelProto>; 24 25 class OnnxParser : public IOnnxParser 26 { 27 28 using OperationParsingFunction = void(OnnxParser::*)(const onnx::NodeProto& NodeProto); 29 30 public: 31 32 using GraphPtr = std::unique_ptr<onnx::GraphProto>; 33 34 /// Create the network from a protobuf binary file on disk 35 virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile) override; 36 37 /// Create the network from a protobuf text file on disk 38 virtual armnn::INetworkPtr CreateNetworkFromTextFile(const char* graphFile) override; 39 40 /// Create the network directly from protobuf text in a string. Useful for debugging/testing 41 virtual armnn::INetworkPtr CreateNetworkFromString(const std::string& protoText) override; 42 43 /// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name 44 virtual BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const override; 45 46 /// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name 47 virtual BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const override; 48 49 public: 50 51 OnnxParser(); 52 53 static ModelPtr LoadModelFromBinaryFile(const char * fileName); 54 static ModelPtr LoadModelFromTextFile(const char * fileName); 55 static ModelPtr LoadModelFromString(const std::string& inputString); 56 57 /// Retrieve inputs names 58 static std::vector<std::string> GetInputs(ModelPtr& model); 59 60 /// Retrieve outputs names 61 static std::vector<std::string> GetOutputs(ModelPtr& model); 62 63 private: 64 65 /// Parses a ModelProto loaded into memory from one of the other CreateNetwork* 66 armnn::INetworkPtr CreateNetworkFromModel(onnx::ModelProto& model); 67 68 /// Parse every node and make the connection between the resulting tensors 69 void LoadGraph(); 70 71 void SetupInfo(const google::protobuf::RepeatedPtrField<onnx::ValueInfoProto >* list); 72 73 std::vector<armnn::TensorInfo> ComputeOutputInfo(std::vector<std::string> outNames, 74 const armnn::IConnectableLayer* layer, 75 std::vector<armnn::TensorShape> inputShapes); 76 77 void DetectFullyConnected(); 78 79 template <typename Location> 80 void GetInputAndParam(const onnx::NodeProto& node, 81 std::string* inputName, 82 std::string* constName, 83 const Location& location); 84 85 template <typename Location> 86 void To1DTensor(const std::string &name, const Location& location); 87 88 //Broadcast Preparation functions 89 std::pair<std::string, std::string> AddPrepareBroadcast(const std::string& input0, const std::string& input1); 90 void PrependForBroadcast(const std::string& outputName, const std::string& input0, const std::string& input1); 91 92 void AddConvLayerWithDepthwiseConv(const onnx::NodeProto& node, const armnn::Convolution2dDescriptor& convDesc); 93 void AddFullyConnected(const onnx::NodeProto& matmulNode, const onnx::NodeProto* addNode = nullptr); 94 void AddPoolingLayer(const onnx::NodeProto& nodeProto, armnn::Pooling2dDescriptor& desc); 95 96 void CreateConstantLayer(const std::string& tensorName, const std::string& layerName); 97 void CreateReshapeLayer(const std::string& inputName, 98 const std::string& outputName, 99 const std::string& layerName); 100 101 void ParseActivation(const onnx::NodeProto& nodeProto, const armnn::ActivationFunction func); 102 void ParseClip(const onnx::NodeProto& nodeProto); 103 void ParseSigmoid(const onnx::NodeProto& nodeProto); 104 void ParseTanh(const onnx::NodeProto& nodeProto); 105 void ParseRelu(const onnx::NodeProto& nodeProto); 106 void ParseLeakyRelu(const onnx::NodeProto& nodeProto); 107 108 void ParseAdd(const onnx::NodeProto& nodeProto); 109 void ParseAveragePool(const onnx::NodeProto& nodeProto); 110 void ParseBatchNormalization(const onnx::NodeProto& node); 111 void ParseConstant(const onnx::NodeProto& nodeProto); 112 void ParseConv(const onnx::NodeProto& nodeProto); 113 void ParseFlatten(const onnx::NodeProto& node); 114 void ParseGlobalAveragePool(const onnx::NodeProto& node); 115 void ParseMaxPool(const onnx::NodeProto& nodeProto); 116 void ParseReshape(const onnx::NodeProto& nodeProto); 117 118 void RegisterInputSlots(armnn::IConnectableLayer* layer, const std::vector<std::string>& tensorIndexes); 119 void RegisterOutputSlots(armnn::IConnectableLayer* layer, const std::vector<std::string>& tensorIndexes); 120 121 void SetupInputLayers(); 122 void SetupOutputLayers(); 123 124 void ResetParser(); 125 void Cleanup(); 126 127 std::pair<armnn::ConstTensor, std::unique_ptr<float[]>> CreateConstTensor(const std::string name); 128 129 template <typename TypeList, typename Location> 130 void ValidateInputs(const onnx::NodeProto& node, 131 TypeList validInputs, 132 const Location& location); 133 134 /// The network we're building. Gets cleared after it is passed to the user 135 armnn::INetworkPtr m_Network; 136 137 /// Ptr to the graph we're building the network from 138 GraphPtr m_Graph; 139 140 /// Map of the information for every tensor 141 struct OnnxTensor 142 { 143 std::unique_ptr<armnn::TensorInfo> m_info; 144 std::unique_ptr<const onnx::TensorProto> m_tensor; 145 onnx::TensorProto::DataType m_dtype; 146 OnnxTensorarmnnOnnxParser::OnnxParser::OnnxTensor147 OnnxTensor() : m_info(nullptr), m_tensor(nullptr), m_dtype(onnx::TensorProto::FLOAT) { } isConstantarmnnOnnxParser::OnnxParser::OnnxTensor148 bool isConstant() { return m_tensor != nullptr; } 149 }; 150 151 std::unordered_map<std::string, OnnxTensor> m_TensorsInfo; 152 153 /// map of onnx operation names to parsing member functions 154 static const std::map<std::string, OperationParsingFunction> m_ParserFunctions; 155 156 /// A mapping of an output slot to each of the input slots it should be connected to 157 /// The outputSlot is from the layer that creates this tensor as one of its ouputs 158 /// The inputSlots are from the layers that use this tensor as one of their inputs 159 struct TensorSlots 160 { 161 armnn::IOutputSlot* outputSlot; 162 std::vector<armnn::IInputSlot*> inputSlots; 163 TensorSlotsarmnnOnnxParser::OnnxParser::TensorSlots164 TensorSlots() : outputSlot(nullptr) { } 165 }; 166 /// Map of the tensor names to their connections for the connections of the layers of the graph 167 std::unordered_map<std::string, TensorSlots> m_TensorConnections; 168 169 /// Map of the tensor names to their node and index in graph.node() 170 std::unordered_map<std::string, std::pair<const onnx::NodeProto*, int>> m_OutputsMap; 171 172 /// Number of times a specific node (identified by his index number) was used as input 173 /// and list of the nodes it was fused with 174 struct UsageSummary 175 { 176 std::vector<size_t> fusedWithNodes; 177 size_t inputForNodes; 178 UsageSummaryarmnnOnnxParser::OnnxParser::UsageSummary179 UsageSummary() : fusedWithNodes({}), inputForNodes(0) { } 180 181 }; 182 183 std::vector<UsageSummary> m_OutputsFusedAndUsed; 184 185 }; 186 } 187