• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
#include "armnn/Types.hpp"
#include "armnn/Tensor.hpp"
#include "armnn/INetwork.hpp"

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
15 
16 namespace armnnTfParser
17 {
18 
19 using BindingPointInfo = armnn::BindingPointInfo;
20 
21 class ITfParser;
22 using ITfParserPtr = std::unique_ptr<ITfParser, void(*)(ITfParser* parser)>;
23 
24 /// Parses a directed acyclic graph from a tensorflow protobuf file.
25 class ITfParser
26 {
27 public:
28     static ITfParser* CreateRaw();
29     static ITfParserPtr Create();
30     static void Destroy(ITfParser* parser);
31 
32     /// Create the network from a protobuf text file on the disk.
33     virtual armnn::INetworkPtr CreateNetworkFromTextFile(
34         const char* graphFile,
35         const std::map<std::string, armnn::TensorShape>& inputShapes,
36         const std::vector<std::string>& requestedOutputs) = 0;
37 
38     /// Create the network from a protobuf binary file on the disk.
39     virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(
40         const char* graphFile,
41         const std::map<std::string, armnn::TensorShape>& inputShapes,
42         const std::vector<std::string>& requestedOutputs) = 0;
43 
44     /// Create the network directly from protobuf text in a string. Useful for debugging/testing.
45     virtual armnn::INetworkPtr CreateNetworkFromString(
46         const char* protoText,
47         const std::map<std::string, armnn::TensorShape>& inputShapes,
48         const std::vector<std::string>& requestedOutputs) = 0;
49 
50     /// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name.
51     virtual BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const = 0;
52 
53     /// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name.
54     virtual BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const = 0;
55 
56 protected:
~ITfParser()57     virtual ~ITfParser() {};
58 };
59 
60 }
61