//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 #pragma once
6 
7 #include "armnnTfParser/ITfParser.hpp"
8 
9 #include "armnn/Types.hpp"
10 #include "armnn/Tensor.hpp"
11 #include "armnn/INetwork.hpp"
12 
13 #include <list>
14 #include <map>
15 #include <memory>
16 #include <unordered_map>
17 #include <utility>
18 #include <vector>
19 
// Forward declaration of armnn::TensorInfo.
// NOTE(review): likely redundant with the "armnn/Tensor.hpp" include above — confirm before removing.
namespace armnn { class TensorInfo; }
24 
// TensorFlow protobuf message types. This header only refers to them by pointer or
// reference, so forward declarations suffice and no generated protobuf header is included.
namespace tensorflow { class GraphDef; class NodeDef; }
30 
31 namespace armnnTfParser
32 {
33 
/// Result of parsing a single TensorFlow node; only forward-declared in this header.
class ParsedTfOperation;

/// Owning handle to a ParsedTfOperation (the Parse* member functions of TfParser return these).
using ParsedTfOperationPtr = std::unique_ptr<ParsedTfOperation>;
36 
///
/// WithOutputTensorIndex pairs a value with an output-tensor index. It models the TensorFlow
/// convention that a node input is written as 'inputTensorName:#index', where the ':#index'
/// part may be omitted, in which case it implicitly means output 0 of the referenced node.
/// Supporting this notation lets the parser handle nodes with multiple outputs, such as Split.
///
template <typename T>
struct WithOutputTensorIndex
{
    T                m_IndexedValue; ///< The wrapped value (e.g. a node pointer or a tensor name).
    unsigned int     m_Index;        ///< Index of the referenced output tensor.

    /// Wraps a copy of @p value together with output index @p index.
    WithOutputTensorIndex(const T & value, unsigned int index)
    : m_IndexedValue{value}
    , m_Index{index} {}

    /// Wraps @p value (moved in) together with output index @p index.
    /// A named rvalue reference is itself an lvalue, so std::move is required here; without it
    /// this overload would copy the value (and fail to compile for move-only types).
    WithOutputTensorIndex(T && value, unsigned int index)
    : m_IndexedValue{std::move(value)}
    , m_Index{index} {}
};
59 
60 using OutputOfParsedTfOperation = WithOutputTensorIndex<ParsedTfOperation *>;
61 using OutputOfConstNodeDef = WithOutputTensorIndex<const tensorflow::NodeDef*>;
62 using OutputId = WithOutputTensorIndex<std::string>;
63 
64 class TfParser : public ITfParser
65 {
66 public:
67     /// Creates the network from a protobuf text file on the disk.
68     virtual armnn::INetworkPtr CreateNetworkFromTextFile(
69         const char* graphFile,
70         const std::map<std::string, armnn::TensorShape>& inputShapes,
71         const std::vector<std::string>& requestedOutputs) override;
72 
73     /// Creates the network from a protobuf binary file on the disk.
74     virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(
75         const char* graphFile,
76         const std::map<std::string, armnn::TensorShape>& inputShapes,
77         const std::vector<std::string>& requestedOutputs) override;
78 
79     /// Creates the network directly from protobuf text in a string. Useful for debugging/testing.
80     virtual armnn::INetworkPtr CreateNetworkFromString(
81         const char* protoText,
82         const std::map<std::string, armnn::TensorShape>& inputShapes,
83         const std::vector<std::string>& requestedOutputs) override;
84 
85     /// Retrieves binding info (layer id and tensor info) for the network input identified by the given layer name.
86     virtual BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const override;
87 
88     /// Retrieves binding info (layer id and tensor info) for the network output identified by the given layer name.
89     virtual BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const override;
90 
91 public:
92     TfParser();
93 
94 private:
95     template <typename T>
96     friend class ParsedConstTfOperation;
97     friend class ParsedMatMulTfOperation;
98     friend class ParsedMulTfOperation;
99 
100     /// Parses a GraphDef loaded into memory from one of the other CreateNetwork*.
101     armnn::INetworkPtr CreateNetworkFromGraphDef(const tensorflow::GraphDef& graphDef,
102         const std::map<std::string, armnn::TensorShape>& inputShapes,
103         const std::vector<std::string>& requestedOutputs);
104 
105     /// Sets up variables and then performs BFS to parse all nodes.
106     void LoadGraphDef(const tensorflow::GraphDef& graphDef);
107 
108     /// Parses a given node, assuming nodes before it in the graph have been done.
109     void LoadNodeDef(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
110 
111     /// Handling identity layers as the input for Conv2D layer.
112     const tensorflow::NodeDef* ResolveIdentityNode(const tensorflow::NodeDef* nodeDef);
113     /// Finds the nodes connected as inputs of the given node in the graph.
114     std::vector<OutputOfConstNodeDef> GetTfInputNodes(const tensorflow::NodeDef& nodeDef) const;
115     /// Finds the IParsedTfOperations for the nodes connected as inputs of the given node in the graph,
116     /// and throws an exception if the number of inputs does not match the expected one.
117     /// This will automatically resolve any identity nodes. The result vector contains the parsed operation
118     /// together with the output tensor index to make the connection unambiguous.
119     std::vector<OutputOfParsedTfOperation> GetInputParsedTfOperationsChecked(const tensorflow::NodeDef& nodeDef,
120                                                                              std::size_t expectedNumInputs);
121 
122     ParsedTfOperationPtr ParseConst(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
123 
124     /// Checks if there is a pre-parsed const tensor available with the given name and Type.
125     template<typename Type>
126     bool HasParsedConstTensor(const std::string & nodeName) const;
127     template<typename Type>
128     bool HasParsedConstTensor(ParsedTfOperation* parsedTfOpPtr) const;
129 
130     unsigned int GetConstInputIndex(const std::vector<OutputOfParsedTfOperation>& inputs);
131 
132     ParsedTfOperationPtr ParseAdd(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
133     ParsedTfOperationPtr ParseAddN(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
134     ParsedTfOperationPtr ParseBiasAdd(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
135     ParsedTfOperationPtr ParseConv2D(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
136     ParsedTfOperationPtr ParseDepthwiseConv2D(const tensorflow::NodeDef& nodeDef,
137                                               const tensorflow::GraphDef& graphDef);
138     ParsedTfOperationPtr ParseExpandDims(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
139     ParsedTfOperationPtr ParseFusedBatchNorm(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
140     ParsedTfOperationPtr ParseConcat(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
141     ParsedTfOperationPtr ParseIdentity(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
142     ParsedTfOperationPtr ParseLrn(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
143     ParsedTfOperationPtr ParseMatMul(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
144     ParsedTfOperationPtr ParseMean(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
145     ParsedTfOperationPtr ParseMul(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
146     ParsedTfOperationPtr ParsePlaceholder(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
147     ParsedTfOperationPtr ParseRealDiv(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
148     ParsedTfOperationPtr ParseRelu(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
149     ParsedTfOperationPtr ParseRelu6(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
150     ParsedTfOperationPtr ParseReshape(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
151     ParsedTfOperationPtr ParseResizeBilinear(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
152     ParsedTfOperationPtr ParseRsqrt(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
153     ParsedTfOperationPtr ParseShape(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
154     ParsedTfOperationPtr ParseSqueeze(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
155     ParsedTfOperationPtr ParseSigmoid(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
156     ParsedTfOperationPtr ParseSoftmax(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
157     ParsedTfOperationPtr ParseSoftplus(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
158     ParsedTfOperationPtr ParseSplit(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
159     ParsedTfOperationPtr ParseStridedSlice(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
160     ParsedTfOperationPtr ParseTanh(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
161     ParsedTfOperationPtr ParseMaxPool(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
162     ParsedTfOperationPtr ParseAvgPool(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
163     ParsedTfOperationPtr ParsePooling2d(const tensorflow::NodeDef& nodeDef,
164                                         const tensorflow::GraphDef& graphDef,
165                                         armnn::PoolingAlgorithm pooltype);
166     ParsedTfOperationPtr ParseEqual(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
167     ParsedTfOperationPtr ParseMaximum(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
168     ParsedTfOperationPtr ParseMinimum(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
169     ParsedTfOperationPtr ParseGather(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
170     ParsedTfOperationPtr ParseGreater(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
171     ParsedTfOperationPtr ParsePad(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
172     ParsedTfOperationPtr ParseSub(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
173     ParsedTfOperationPtr ParseStack(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
174     ParsedTfOperationPtr ParseTranspose(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
175     ParsedTfOperationPtr AddActivationLayer(const tensorflow::NodeDef& nodeDef, armnn::ActivationDescriptor& desc);
176     ParsedTfOperationPtr AddAdditionLayer(const tensorflow::NodeDef& nodeDef, bool isBiasAdd = false);
177     ParsedTfOperationPtr AddRealDivLayer(const tensorflow::NodeDef& nodeDef);
178     ParsedTfOperationPtr AddMaximumLayer(const tensorflow::NodeDef& nodeDef);
179 
180 private:
181     armnn::IConnectableLayer* AddMultiplicationLayer(const tensorflow::NodeDef& nodeDef);
182 
183     armnn::IConnectableLayer* AddFullyConnectedLayer(const tensorflow::NodeDef& matMulNodeDef,
184         const tensorflow::NodeDef* addNodeDef, const char* armnnLayerName);
185 
186     bool IsSupportedLeakyReluPattern(const tensorflow::NodeDef& mulNodeDef,
187                                     size_t alphaLayerIndex,
188                                     const OutputOfParsedTfOperation& otherOp,
189                                     armnn::IOutputSlot** outputOfLeakyRelu,
190                                     armnn::ActivationDescriptor & desc);
191 
192     std::pair<armnn::IOutputSlot*, armnn::IOutputSlot*> ProcessElementwiseInputSlots(
193             const tensorflow::NodeDef& nodeDef, const std::string& layerName);
194 
195     ParsedTfOperationPtr ProcessComparisonLayer(
196         armnn::IOutputSlot* input0Slot,
197         armnn::IOutputSlot* input1Slot,
198         armnn::IConnectableLayer* const layer,
199         const tensorflow::NodeDef& nodeDef);
200 
201     ParsedTfOperationPtr ProcessElementwiseLayer(
202             armnn::IOutputSlot* input0Slot,
203             armnn::IOutputSlot* input1Slot,
204             armnn::IConnectableLayer* const layer,
205             const tensorflow::NodeDef& nodeDef);
206 
207     armnn::IConnectableLayer* CreateAdditionLayer(
208             const tensorflow::NodeDef& nodeDef,
209             armnn::IOutputSlot* input0Slot,
210             armnn::IOutputSlot* input1Slot,
211             const std::string& layerName);
212 
213     armnn::IConnectableLayer* CreateAdditionLayer(
214             const tensorflow::NodeDef& nodeDef,
215             const OutputOfParsedTfOperation& opOne,
216             const OutputOfParsedTfOperation& opTwo,
217             unsigned int numberOfAddition);
218 
219     armnn::IConnectableLayer* CreateAdditionLayer(
220             const tensorflow::NodeDef& nodeDef,
221             armnn::IConnectableLayer* layerOne,
222             armnn::IConnectableLayer* layerTwo,
223             unsigned int numberOfAddition,
224             unsigned long numberOfLayersToConnect,
225             bool isOdd);
226 
227     armnn::IConnectableLayer* CreateAdditionLayer(
228             const tensorflow::NodeDef& nodeDef,
229             const OutputOfParsedTfOperation& op,
230             armnn::IConnectableLayer* layer);
231 
232     static std::pair<armnn::LayerBindingId, armnn::TensorInfo> GetBindingInfo(const std::string& layerName,
233         const char* bindingPointDesc,
234         const std::unordered_map<std::string, BindingPointInfo>& nameToBindingInfo);
235 
236     void TrackInputBinding(armnn::IConnectableLayer* layer,
237         armnn::LayerBindingId id,
238         const armnn::TensorInfo& tensorInfo);
239 
240     void TrackOutputBinding(armnn::IConnectableLayer* layer,
241         armnn::LayerBindingId id,
242         const armnn::TensorInfo& tensorInfo);
243 
244     static void TrackBindingPoint(armnn::IConnectableLayer* layer, armnn::LayerBindingId id,
245         const armnn::TensorInfo& tensorInfo,
246         const char* bindingPointDesc,
247         std::unordered_map<std::string, BindingPointInfo>& nameToBindingInfo);
248 
249     void Cleanup();
250 
251     /// The network we're building. Gets cleared after it is passed to the user.
252     armnn::INetworkPtr m_Network;
253 
254     using OperationParsingFunction = ParsedTfOperationPtr(TfParser::*)(const tensorflow::NodeDef& nodeDef,
255                                                                  const tensorflow::GraphDef& graphDef);
256 
257     /// Map of TensorFlow operation names to parsing member functions.
258     static const std::map<std::string, OperationParsingFunction> ms_OperationNameToParsingFunctions;
259 
260     static const std::list<std::string> m_ControlInputs;
261 
262     std::map<std::string, armnn::TensorShape> m_InputShapes;
263     std::vector<std::string> m_RequestedOutputs;
264 
265     /// Map of nodes extracted from the GraphDef to speed up parsing.
266     std::unordered_map<std::string, const tensorflow::NodeDef*> m_NodesByName;
267 
268     std::unordered_map<std::string, ParsedTfOperationPtr> m_ParsedTfOperations;
269 
270     /// Maps input layer names to their corresponding ids and tensor info.
271     std::unordered_map<std::string, BindingPointInfo> m_NetworkInputsBindingInfo;
272 
273     /// Maps output layer names to their corresponding ids and tensor info.
274     std::unordered_map<std::string, BindingPointInfo> m_NetworkOutputsBindingInfo;
275 };
276 
277 }
278