• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "LoadedNetwork.hpp"
8 #include "DeviceSpec.hpp"
9 
10 #include <armnn/INetwork.hpp>
11 #include <armnn/IRuntime.hpp>
12 #include <armnn/Tensor.hpp>
13 #include <armnn/BackendId.hpp>
14 
15 #include <armnn/backends/DynamicBackend.hpp>
16 
17 #include <ProfilingService.hpp>
18 
19 #include <IProfilingService.hpp>
20 #include <IReportStructure.hpp>
21 
22 #include <mutex>
23 #include <unordered_map>
24 
25 namespace armnn
26 {
27 using LoadedNetworks = std::unordered_map<NetworkId, std::unique_ptr<LoadedNetwork>>;
28 using IReportStructure = profiling::IReportStructure;
29 
30 class Runtime final :  public IRuntime,
31                        public IReportStructure
32 {
33 public:
34     /// Loads a complete network into the Runtime.
35     /// @param [out] networkIdOut - Unique identifier for the network is returned in this reference.
36     /// @param [in] network - Complete network to load into the Runtime.
37     /// The runtime takes ownership of the network once passed in.
38     /// @return armnn::Status
39     virtual Status LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr network) override;
40 
41     /// Load a complete network into the IRuntime.
42     /// @param [out] networkIdOut Unique identifier for the network is returned in this reference.
43     /// @param [in] network Complete network to load into the IRuntime.
44     /// @param [out] errorMessage Error message if there were any errors.
45     /// The runtime takes ownership of the network once passed in.
46     /// @return armnn::Status
47     virtual Status LoadNetwork(NetworkId& networkIdOut,
48                                IOptimizedNetworkPtr network,
49                                std::string& errorMessage) override;
50 
51     virtual Status LoadNetwork(NetworkId& networkIdOut,
52                                IOptimizedNetworkPtr network,
53                                std::string& errorMessage,
54                                const INetworkProperties& networkProperties) override;
55 
56     virtual TensorInfo GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const override;
57     virtual TensorInfo GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const override;
58 
59     // Evaluates network using input in inputTensors, outputs filled into outputTensors.
60     virtual Status EnqueueWorkload(NetworkId networkId,
61         const InputTensors& inputTensors,
62         const OutputTensors& outputTensors) override;
63 
64     /// Unloads a network from the Runtime.
65     /// At the moment this only removes the network from the m_Impl->m_Network.
66     /// This might need more work in the future to be AndroidNN compliant.
67     /// @param [in] networkId Unique identifier for the network to be unloaded. Generated in LoadNetwork().
68     /// @return armnn::Status
69     virtual Status UnloadNetwork(NetworkId networkId) override;
70 
GetDeviceSpec() const71     virtual const IDeviceSpec& GetDeviceSpec() const override { return m_DeviceSpec; }
72 
73     /// Gets the profiler corresponding to the given network id.
74     /// @param networkId The id of the network for which to get the profile.
75     /// @return A pointer to the requested profiler, or nullptr if not found.
76     virtual const std::shared_ptr<IProfiler> GetProfiler(NetworkId networkId) const override;
77 
78     /// Registers a callback function to debug layers performing custom computations on intermediate tensors.
79     /// @param networkId The id of the network to register the callback.
80     /// @param func callback function to pass to the debug layer.
81     virtual void RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func) override;
82 
83     /// Creates a runtime for workload execution.
84     Runtime(const CreationOptions& options);
85 
86     ~Runtime();
87 
88     //NOTE: we won't need the profiling service reference but it is good to pass the service
89     // in this way to facilitate other implementations down the road
90     virtual void ReportStructure() override;
91 
92 private:
93     friend void RuntimeLoadedNetworksReserve(armnn::Runtime* runtime); // See RuntimeTests.cpp
94 
95     friend profiling::ProfilingService& GetProfilingService(armnn::Runtime* runtime); // See RuntimeTests.cpp
96 
97     int GenerateNetworkId();
98 
99     LoadedNetwork* GetLoadedNetworkPtr(NetworkId networkId) const;
100 
101     template<typename Func>
LoadedNetworkFuncSafe(NetworkId networkId,Func f)102     void LoadedNetworkFuncSafe(NetworkId networkId, Func f)
103     {
104         std::lock_guard<std::mutex> lockGuard(m_Mutex);
105         auto iter = m_LoadedNetworks.find(networkId);
106         if (iter != m_LoadedNetworks.end())
107         {
108             f(iter->second.get());
109         }
110     }
111 
112     /// Loads any available/compatible dynamic backend in the runtime.
113     void LoadDynamicBackends(const std::string& overrideBackendPath);
114 
115     mutable std::mutex m_Mutex;
116 
117     /// Map of Loaded Networks with associated GUID as key
118     LoadedNetworks m_LoadedNetworks;
119 
120     std::unordered_map<BackendId, IBackendInternal::IBackendContextPtr> m_BackendContexts;
121 
122     int m_NetworkIdCounter;
123 
124     DeviceSpec m_DeviceSpec;
125 
126     /// List of dynamic backends loaded in the runtime
127     std::vector<DynamicBackendPtr> m_DynamicBackends;
128 
129     /// Profiling Service Instance
130     profiling::ProfilingService m_ProfilingService;
131 };
132 
133 } // namespace armnn
134