1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include "armnnTfParser/ITfParser.hpp" 8 9 #include "armnn/Types.hpp" 10 #include "armnn/Tensor.hpp" 11 #include "armnn/INetwork.hpp" 12 13 #include <list> 14 #include <map> 15 #include <memory> 16 #include <unordered_map> 17 #include <utility> 18 #include <vector> 19 20 namespace armnn 21 { 22 class TensorInfo; 23 } 24 25 namespace tensorflow 26 { 27 class GraphDef; 28 class NodeDef; 29 } 30 31 namespace armnnTfParser 32 { 33 34 class ParsedTfOperation; 35 using ParsedTfOperationPtr = std::unique_ptr<ParsedTfOperation>; 36 37 /// 38 /// WithOutputTensorIndex wraps a value and an index. The purpose of 39 /// this template is to signify that, in Tensorflow, the input name of 40 /// a layer has the convention of 'inputTensorName:#index', where the 41 /// #index can be omitted and it implicitly means the 0 output of 42 /// the referenced layer. By supporting this notation we can handle 43 /// layers with multiple outputs, such as Split. 44 /// 45 template <typename T> 46 struct WithOutputTensorIndex 47 { 48 T m_IndexedValue; 49 unsigned int m_Index; 50 WithOutputTensorIndexarmnnTfParser::WithOutputTensorIndex51 WithOutputTensorIndex(const T & value, unsigned int index) 52 : m_IndexedValue{value} 53 , m_Index{index} {} 54 WithOutputTensorIndexarmnnTfParser::WithOutputTensorIndex55 WithOutputTensorIndex(T && value, unsigned int index) 56 : m_IndexedValue{value} 57 , m_Index{index} {} 58 }; 59 60 using OutputOfParsedTfOperation = WithOutputTensorIndex<ParsedTfOperation *>; 61 using OutputOfConstNodeDef = WithOutputTensorIndex<const tensorflow::NodeDef*>; 62 using OutputId = WithOutputTensorIndex<std::string>; 63 64 class TfParser : public ITfParser 65 { 66 public: 67 /// Creates the network from a protobuf text file on the disk. 68 virtual armnn::INetworkPtr CreateNetworkFromTextFile( 69 const char* graphFile, 70 const std::map<std::string, armnn::TensorShape>& inputShapes, 71 const std::vector<std::string>& requestedOutputs) override; 72 73 /// Creates the network from a protobuf binary file on the disk. 74 virtual armnn::INetworkPtr CreateNetworkFromBinaryFile( 75 const char* graphFile, 76 const std::map<std::string, armnn::TensorShape>& inputShapes, 77 const std::vector<std::string>& requestedOutputs) override; 78 79 /// Creates the network directly from protobuf text in a string. Useful for debugging/testing. 80 virtual armnn::INetworkPtr CreateNetworkFromString( 81 const char* protoText, 82 const std::map<std::string, armnn::TensorShape>& inputShapes, 83 const std::vector<std::string>& requestedOutputs) override; 84 85 /// Retrieves binding info (layer id and tensor info) for the network input identified by the given layer name. 86 virtual BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const override; 87 88 /// Retrieves binding info (layer id and tensor info) for the network output identified by the given layer name. 89 virtual BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const override; 90 91 public: 92 TfParser(); 93 94 private: 95 template <typename T> 96 friend class ParsedConstTfOperation; 97 friend class ParsedMatMulTfOperation; 98 friend class ParsedMulTfOperation; 99 100 /// Parses a GraphDef loaded into memory from one of the other CreateNetwork*. 101 armnn::INetworkPtr CreateNetworkFromGraphDef(const tensorflow::GraphDef& graphDef, 102 const std::map<std::string, armnn::TensorShape>& inputShapes, 103 const std::vector<std::string>& requestedOutputs); 104 105 /// Sets up variables and then performs BFS to parse all nodes. 106 void LoadGraphDef(const tensorflow::GraphDef& graphDef); 107 108 /// Parses a given node, assuming nodes before it in the graph have been done. 109 void LoadNodeDef(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 110 111 /// Handling identity layers as the input for Conv2D layer. 112 const tensorflow::NodeDef* ResolveIdentityNode(const tensorflow::NodeDef* nodeDef); 113 /// Finds the nodes connected as inputs of the given node in the graph. 114 std::vector<OutputOfConstNodeDef> GetTfInputNodes(const tensorflow::NodeDef& nodeDef) const; 115 /// Finds the IParsedTfOperations for the nodes connected as inputs of the given node in the graph, 116 /// and throws an exception if the number of inputs does not match the expected one. 117 /// This will automatically resolve any identity nodes. The result vector contains the parsed operation 118 /// together with the output tensor index to make the connection unambiguous. 119 std::vector<OutputOfParsedTfOperation> GetInputParsedTfOperationsChecked(const tensorflow::NodeDef& nodeDef, 120 std::size_t expectedNumInputs); 121 122 ParsedTfOperationPtr ParseConst(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 123 124 /// Checks if there is a pre-parsed const tensor available with the given name and Type. 125 template<typename Type> 126 bool HasParsedConstTensor(const std::string & nodeName) const; 127 template<typename Type> 128 bool HasParsedConstTensor(ParsedTfOperation* parsedTfOpPtr) const; 129 130 unsigned int GetConstInputIndex(const std::vector<OutputOfParsedTfOperation>& inputs); 131 132 ParsedTfOperationPtr ParseAdd(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 133 ParsedTfOperationPtr ParseAddN(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 134 ParsedTfOperationPtr ParseBiasAdd(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 135 ParsedTfOperationPtr ParseConv2D(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 136 ParsedTfOperationPtr ParseDepthwiseConv2D(const tensorflow::NodeDef& nodeDef, 137 const tensorflow::GraphDef& graphDef); 138 ParsedTfOperationPtr ParseExpandDims(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 139 ParsedTfOperationPtr ParseFusedBatchNorm(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 140 ParsedTfOperationPtr ParseConcat(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 141 ParsedTfOperationPtr ParseIdentity(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 142 ParsedTfOperationPtr ParseLrn(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 143 ParsedTfOperationPtr ParseMatMul(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 144 ParsedTfOperationPtr ParseMean(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 145 ParsedTfOperationPtr ParseMul(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 146 ParsedTfOperationPtr ParsePlaceholder(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 147 ParsedTfOperationPtr ParseRealDiv(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 148 ParsedTfOperationPtr ParseRelu(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 149 ParsedTfOperationPtr ParseRelu6(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 150 ParsedTfOperationPtr ParseReshape(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 151 ParsedTfOperationPtr ParseResizeBilinear(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 152 ParsedTfOperationPtr ParseRsqrt(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 153 ParsedTfOperationPtr ParseShape(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 154 ParsedTfOperationPtr ParseSqueeze(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 155 ParsedTfOperationPtr ParseSigmoid(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 156 ParsedTfOperationPtr ParseSoftmax(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 157 ParsedTfOperationPtr ParseSoftplus(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 158 ParsedTfOperationPtr ParseSplit(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 159 ParsedTfOperationPtr ParseStridedSlice(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 160 ParsedTfOperationPtr ParseTanh(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 161 ParsedTfOperationPtr ParseMaxPool(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 162 ParsedTfOperationPtr ParseAvgPool(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 163 ParsedTfOperationPtr ParsePooling2d(const tensorflow::NodeDef& nodeDef, 164 const tensorflow::GraphDef& graphDef, 165 armnn::PoolingAlgorithm pooltype); 166 ParsedTfOperationPtr ParseEqual(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 167 ParsedTfOperationPtr ParseMaximum(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 168 ParsedTfOperationPtr ParseMinimum(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 169 ParsedTfOperationPtr ParseGather(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 170 ParsedTfOperationPtr ParseGreater(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 171 ParsedTfOperationPtr ParsePad(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 172 ParsedTfOperationPtr ParseSub(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 173 ParsedTfOperationPtr ParseStack(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 174 ParsedTfOperationPtr ParseTranspose(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); 175 ParsedTfOperationPtr AddActivationLayer(const tensorflow::NodeDef& nodeDef, armnn::ActivationDescriptor& desc); 176 ParsedTfOperationPtr AddAdditionLayer(const tensorflow::NodeDef& nodeDef, bool isBiasAdd = false); 177 ParsedTfOperationPtr AddRealDivLayer(const tensorflow::NodeDef& nodeDef); 178 ParsedTfOperationPtr AddMaximumLayer(const tensorflow::NodeDef& nodeDef); 179 180 private: 181 armnn::IConnectableLayer* AddMultiplicationLayer(const tensorflow::NodeDef& nodeDef); 182 183 armnn::IConnectableLayer* AddFullyConnectedLayer(const tensorflow::NodeDef& matMulNodeDef, 184 const tensorflow::NodeDef* addNodeDef, const char* armnnLayerName); 185 186 bool IsSupportedLeakyReluPattern(const tensorflow::NodeDef& mulNodeDef, 187 size_t alphaLayerIndex, 188 const OutputOfParsedTfOperation& otherOp, 189 armnn::IOutputSlot** outputOfLeakyRelu, 190 armnn::ActivationDescriptor & desc); 191 192 std::pair<armnn::IOutputSlot*, armnn::IOutputSlot*> ProcessElementwiseInputSlots( 193 const tensorflow::NodeDef& nodeDef, const std::string& layerName); 194 195 ParsedTfOperationPtr ProcessComparisonLayer( 196 armnn::IOutputSlot* input0Slot, 197 armnn::IOutputSlot* input1Slot, 198 armnn::IConnectableLayer* const layer, 199 const tensorflow::NodeDef& nodeDef); 200 201 ParsedTfOperationPtr ProcessElementwiseLayer( 202 armnn::IOutputSlot* input0Slot, 203 armnn::IOutputSlot* input1Slot, 204 armnn::IConnectableLayer* const layer, 205 const tensorflow::NodeDef& nodeDef); 206 207 armnn::IConnectableLayer* CreateAdditionLayer( 208 const tensorflow::NodeDef& nodeDef, 209 armnn::IOutputSlot* input0Slot, 210 armnn::IOutputSlot* input1Slot, 211 const std::string& layerName); 212 213 armnn::IConnectableLayer* CreateAdditionLayer( 214 const tensorflow::NodeDef& nodeDef, 215 const OutputOfParsedTfOperation& opOne, 216 const OutputOfParsedTfOperation& opTwo, 217 unsigned int numberOfAddition); 218 219 armnn::IConnectableLayer* CreateAdditionLayer( 220 const tensorflow::NodeDef& nodeDef, 221 armnn::IConnectableLayer* layerOne, 222 armnn::IConnectableLayer* layerTwo, 223 unsigned int numberOfAddition, 224 unsigned long numberOfLayersToConnect, 225 bool isOdd); 226 227 armnn::IConnectableLayer* CreateAdditionLayer( 228 const tensorflow::NodeDef& nodeDef, 229 const OutputOfParsedTfOperation& op, 230 armnn::IConnectableLayer* layer); 231 232 static std::pair<armnn::LayerBindingId, armnn::TensorInfo> GetBindingInfo(const std::string& layerName, 233 const char* bindingPointDesc, 234 const std::unordered_map<std::string, BindingPointInfo>& nameToBindingInfo); 235 236 void TrackInputBinding(armnn::IConnectableLayer* layer, 237 armnn::LayerBindingId id, 238 const armnn::TensorInfo& tensorInfo); 239 240 void TrackOutputBinding(armnn::IConnectableLayer* layer, 241 armnn::LayerBindingId id, 242 const armnn::TensorInfo& tensorInfo); 243 244 static void TrackBindingPoint(armnn::IConnectableLayer* layer, armnn::LayerBindingId id, 245 const armnn::TensorInfo& tensorInfo, 246 const char* bindingPointDesc, 247 std::unordered_map<std::string, BindingPointInfo>& nameToBindingInfo); 248 249 void Cleanup(); 250 251 /// The network we're building. Gets cleared after it is passed to the user. 252 armnn::INetworkPtr m_Network; 253 254 using OperationParsingFunction = ParsedTfOperationPtr(TfParser::*)(const tensorflow::NodeDef& nodeDef, 255 const tensorflow::GraphDef& graphDef); 256 257 /// Map of TensorFlow operation names to parsing member functions. 258 static const std::map<std::string, OperationParsingFunction> ms_OperationNameToParsingFunctions; 259 260 static const std::list<std::string> m_ControlInputs; 261 262 std::map<std::string, armnn::TensorShape> m_InputShapes; 263 std::vector<std::string> m_RequestedOutputs; 264 265 /// Map of nodes extracted from the GraphDef to speed up parsing. 266 std::unordered_map<std::string, const tensorflow::NodeDef*> m_NodesByName; 267 268 std::unordered_map<std::string, ParsedTfOperationPtr> m_ParsedTfOperations; 269 270 /// Maps input layer names to their corresponding ids and tensor info. 271 std::unordered_map<std::string, BindingPointInfo> m_NetworkInputsBindingInfo; 272 273 /// Maps output layer names to their corresponding ids and tensor info. 274 std::unordered_map<std::string, BindingPointInfo> m_NetworkOutputsBindingInfo; 275 }; 276 277 } 278