• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 #pragma once
6 
#include "armnn/Types.hpp"
#include "armnn/NetworkFwd.hpp"
#include "armnn/Tensor.hpp"
#include "armnn/INetwork.hpp"
#include "armnn/Optional.hpp"

#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <vector>
16 
17 namespace armnnTfLiteParser
18 {
19 
20 using BindingPointInfo = armnn::BindingPointInfo;
21 
22 class ITfLiteParser;
23 using ITfLiteParserPtr = std::unique_ptr<ITfLiteParser, void(*)(ITfLiteParser* parser)>;
24 
25 class ITfLiteParser
26 {
27 public:
28     struct TfLiteParserOptions
29     {
TfLiteParserOptionsarmnnTfLiteParser::ITfLiteParser::TfLiteParserOptions30         TfLiteParserOptions()
31             : m_StandInLayerForUnsupported(false),
32               m_InferAndValidate(false) {}
33 
34         bool m_StandInLayerForUnsupported;
35         bool m_InferAndValidate;
36     };
37 
38     static ITfLiteParser* CreateRaw(const armnn::Optional<TfLiteParserOptions>& options = armnn::EmptyOptional());
39     static ITfLiteParserPtr Create(const armnn::Optional<TfLiteParserOptions>& options = armnn::EmptyOptional());
40     static void Destroy(ITfLiteParser* parser);
41 
42     /// Create the network from a flatbuffers binary file on disk
43     virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile) = 0;
44 
45     /// Create the network from a flatbuffers binary
46     virtual armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t> & binaryContent) = 0;
47 
48     /// Retrieve binding info (layer id and tensor info) for the network input identified by
49     /// the given layer name and subgraph id
50     virtual BindingPointInfo GetNetworkInputBindingInfo(size_t subgraphId,
51                                                         const std::string& name) const = 0;
52 
53     /// Retrieve binding info (layer id and tensor info) for the network output identified by
54     /// the given layer name and subgraph id
55     virtual BindingPointInfo GetNetworkOutputBindingInfo(size_t subgraphId,
56                                                          const std::string& name) const = 0;
57 
58     /// Return the number of subgraphs in the parsed model
59     virtual size_t GetSubgraphCount() const = 0;
60 
61     /// Return the input tensor names for a given subgraph
62     virtual std::vector<std::string> GetSubgraphInputTensorNames(size_t subgraphId) const = 0;
63 
64     /// Return the output tensor names for a given subgraph
65     virtual std::vector<std::string> GetSubgraphOutputTensorNames(size_t subgraphId) const = 0;
66 
67 protected:
~ITfLiteParser()68     virtual ~ITfLiteParser() {};
69 };
70 
71 }
72