1 // 2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include "LoadedNetwork.hpp" 8 #include "DeviceSpec.hpp" 9 10 #include <armnn/INetwork.hpp> 11 #include <armnn/IRuntime.hpp> 12 #include <armnn/Tensor.hpp> 13 #include <armnn/BackendId.hpp> 14 15 #include <armnn/backends/DynamicBackend.hpp> 16 17 #include <ProfilingService.hpp> 18 19 #include <IProfilingService.hpp> 20 #include <IReportStructure.hpp> 21 22 #include <mutex> 23 #include <unordered_map> 24 25 namespace armnn 26 { 27 using LoadedNetworks = std::unordered_map<NetworkId, std::unique_ptr<LoadedNetwork>>; 28 using IReportStructure = profiling::IReportStructure; 29 30 class Runtime final : public IRuntime, 31 public IReportStructure 32 { 33 public: 34 /// Loads a complete network into the Runtime. 35 /// @param [out] networkIdOut - Unique identifier for the network is returned in this reference. 36 /// @param [in] network - Complete network to load into the Runtime. 37 /// The runtime takes ownership of the network once passed in. 38 /// @return armnn::Status 39 virtual Status LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr network) override; 40 41 /// Load a complete network into the IRuntime. 42 /// @param [out] networkIdOut Unique identifier for the network is returned in this reference. 43 /// @param [in] network Complete network to load into the IRuntime. 44 /// @param [out] errorMessage Error message if there were any errors. 45 /// The runtime takes ownership of the network once passed in. 46 /// @return armnn::Status 47 virtual Status LoadNetwork(NetworkId& networkIdOut, 48 IOptimizedNetworkPtr network, 49 std::string& errorMessage) override; 50 51 virtual Status LoadNetwork(NetworkId& networkIdOut, 52 IOptimizedNetworkPtr network, 53 std::string& errorMessage, 54 const INetworkProperties& networkProperties) override; 55 56 virtual TensorInfo GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const override; 57 virtual TensorInfo GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const override; 58 59 // Evaluates network using input in inputTensors, outputs filled into outputTensors. 60 virtual Status EnqueueWorkload(NetworkId networkId, 61 const InputTensors& inputTensors, 62 const OutputTensors& outputTensors) override; 63 64 /// Unloads a network from the Runtime. 65 /// At the moment this only removes the network from the m_Impl->m_Network. 66 /// This might need more work in the future to be AndroidNN compliant. 67 /// @param [in] networkId Unique identifier for the network to be unloaded. Generated in LoadNetwork(). 68 /// @return armnn::Status 69 virtual Status UnloadNetwork(NetworkId networkId) override; 70 GetDeviceSpec() const71 virtual const IDeviceSpec& GetDeviceSpec() const override { return m_DeviceSpec; } 72 73 /// Gets the profiler corresponding to the given network id. 74 /// @param networkId The id of the network for which to get the profile. 75 /// @return A pointer to the requested profiler, or nullptr if not found. 76 virtual const std::shared_ptr<IProfiler> GetProfiler(NetworkId networkId) const override; 77 78 /// Registers a callback function to debug layers performing custom computations on intermediate tensors. 79 /// @param networkId The id of the network to register the callback. 80 /// @param func callback function to pass to the debug layer. 81 virtual void RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func) override; 82 83 /// Creates a runtime for workload execution. 84 Runtime(const CreationOptions& options); 85 86 ~Runtime(); 87 88 //NOTE: we won't need the profiling service reference but it is good to pass the service 89 // in this way to facilitate other implementations down the road 90 virtual void ReportStructure() override; 91 92 private: 93 friend void RuntimeLoadedNetworksReserve(armnn::Runtime* runtime); // See RuntimeTests.cpp 94 95 friend profiling::ProfilingService& GetProfilingService(armnn::Runtime* runtime); // See RuntimeTests.cpp 96 97 int GenerateNetworkId(); 98 99 LoadedNetwork* GetLoadedNetworkPtr(NetworkId networkId) const; 100 101 template<typename Func> LoadedNetworkFuncSafe(NetworkId networkId,Func f)102 void LoadedNetworkFuncSafe(NetworkId networkId, Func f) 103 { 104 std::lock_guard<std::mutex> lockGuard(m_Mutex); 105 auto iter = m_LoadedNetworks.find(networkId); 106 if (iter != m_LoadedNetworks.end()) 107 { 108 f(iter->second.get()); 109 } 110 } 111 112 /// Loads any available/compatible dynamic backend in the runtime. 113 void LoadDynamicBackends(const std::string& overrideBackendPath); 114 115 mutable std::mutex m_Mutex; 116 117 /// Map of Loaded Networks with associated GUID as key 118 LoadedNetworks m_LoadedNetworks; 119 120 std::unordered_map<BackendId, IBackendInternal::IBackendContextPtr> m_BackendContexts; 121 122 int m_NetworkIdCounter; 123 124 DeviceSpec m_DeviceSpec; 125 126 /// List of dynamic backends loaded in the runtime 127 std::vector<DynamicBackendPtr> m_DynamicBackends; 128 129 /// Profiling Service Instance 130 profiling::ProfilingService m_ProfilingService; 131 }; 132 133 } // namespace armnn 134