1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include "armnn/Types.hpp" 8 #include "armnn/NetworkFwd.hpp" 9 #include "armnn/Tensor.hpp" 10 #include "armnn/INetwork.hpp" 11 #include "armnn/Optional.hpp" 12 13 #include <memory> 14 #include <map> 15 #include <vector> 16 17 namespace armnnTfLiteParser 18 { 19 20 using BindingPointInfo = armnn::BindingPointInfo; 21 22 class ITfLiteParser; 23 using ITfLiteParserPtr = std::unique_ptr<ITfLiteParser, void(*)(ITfLiteParser* parser)>; 24 25 class ITfLiteParser 26 { 27 public: 28 struct TfLiteParserOptions 29 { TfLiteParserOptionsarmnnTfLiteParser::ITfLiteParser::TfLiteParserOptions30 TfLiteParserOptions() 31 : m_StandInLayerForUnsupported(false), 32 m_InferAndValidate(false) {} 33 34 bool m_StandInLayerForUnsupported; 35 bool m_InferAndValidate; 36 }; 37 38 static ITfLiteParser* CreateRaw(const armnn::Optional<TfLiteParserOptions>& options = armnn::EmptyOptional()); 39 static ITfLiteParserPtr Create(const armnn::Optional<TfLiteParserOptions>& options = armnn::EmptyOptional()); 40 static void Destroy(ITfLiteParser* parser); 41 42 /// Create the network from a flatbuffers binary file on disk 43 virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile) = 0; 44 45 /// Create the network from a flatbuffers binary 46 virtual armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t> & binaryContent) = 0; 47 48 /// Retrieve binding info (layer id and tensor info) for the network input identified by 49 /// the given layer name and subgraph id 50 virtual BindingPointInfo GetNetworkInputBindingInfo(size_t subgraphId, 51 const std::string& name) const = 0; 52 53 /// Retrieve binding info (layer id and tensor info) for the network output identified by 54 /// the given layer name and subgraph id 55 virtual BindingPointInfo GetNetworkOutputBindingInfo(size_t subgraphId, 56 const std::string& name) const = 0; 57 58 /// Return the number of subgraphs in the parsed model 59 virtual size_t GetSubgraphCount() const = 0; 60 61 /// Return the input tensor names for a given subgraph 62 virtual std::vector<std::string> GetSubgraphInputTensorNames(size_t subgraphId) const = 0; 63 64 /// Return the output tensor names for a given subgraph 65 virtual std::vector<std::string> GetSubgraphOutputTensorNames(size_t subgraphId) const = 0; 66 67 protected: ~ITfLiteParser()68 virtual ~ITfLiteParser() {}; 69 }; 70 71 } 72