//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

#include "Network.hpp"
#include "LayerFwd.hpp"
#include "Profiling.hpp"

#include <armnn/backends/IBackendInternal.hpp>
#include <backendsCommon/TensorHandleFactoryRegistry.hpp>
#include <backendsCommon/Workload.hpp>
#include <backendsCommon/WorkloadFactory.hpp>
#include <ProfilingService.hpp>
#include <TimelineUtilityMethods.hpp>

#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Forward declarations only: the types are named here without pulling in their
// definitions. NOTE(review): these look like the OpenCL C++ wrapper classes and
// are not referenced anywhere else in this header as shown - confirm whether any
// includer still relies on them before removing.
namespace cl
{
    class Context;
    class CommandQueue;
    class Device;
}

31 namespace armnn
32 {
33 
34 class LoadedNetwork
35 {
36 public:
37     using WorkloadQueue = std::vector< std::unique_ptr<IWorkload> >;
~LoadedNetwork()38     ~LoadedNetwork(){ FreeWorkingMemory(); }
39 
40     TensorInfo GetInputTensorInfo(LayerBindingId layerId) const;
41     TensorInfo GetOutputTensorInfo(LayerBindingId layerId) const;
42 
43     Status EnqueueWorkload(const InputTensors& inputTensors, const OutputTensors& outputTensors);
44 
45     static std::unique_ptr<LoadedNetwork> MakeLoadedNetwork(std::unique_ptr<OptimizedNetwork> net,
46                                                             std::string & errorMessage,
47                                                             const INetworkProperties& networkProperties,
48                                                             profiling::ProfilingService& profilingService);
49 
50     // NOTE we return by reference as the purpose of this method is only to provide
51     // access to the private m_Profiler and in theory we should not need to increment
52     // the shared_ptr's reference counter
GetProfiler() const53     const std::shared_ptr<Profiler>& GetProfiler() const { return m_Profiler; }
54 
55     void FreeWorkingMemory();
56 
57     void RegisterDebugCallback(const DebugCallbackFunction& func);
58 
59     void SendNetworkStructure();
60 
61     profiling::ProfilingGuid GetNetworkGuid();
62 
63 private:
64     void AllocateWorkingMemory(std::lock_guard<std::mutex>& lock);
65 
66     LoadedNetwork(std::unique_ptr<OptimizedNetwork> net,
67                   const INetworkProperties& networkProperties,
68                   profiling::ProfilingService& profilingService);
69 
70     void EnqueueInput(const BindableLayer& layer, ITensorHandle* tensorHandle, const TensorInfo& tensorInfo);
71 
72     void EnqueueOutput(const BindableLayer& layer, ITensorHandle* tensorHandle, const TensorInfo& tensorInfo);
73 
74     bool Execute(std::unique_ptr<profiling::TimelineUtilityMethods>& timelineUtils,
75                  profiling::ProfilingGuid inferenceGuid);
76 
77 
78     const IWorkloadFactory& GetWorkloadFactory(const Layer& layer) const;
79 
80     using BackendPtrMap = std::unordered_map<BackendId, IBackendInternalUniquePtr>;
81 
82     using WorkloadFactoryWithMemoryManager =
83         std::pair<IBackendInternal::IWorkloadFactoryPtr, IBackendInternal::IMemoryManagerSharedPtr>;
84 
85     using WorkloadFactoryMap = std::unordered_map<BackendId, WorkloadFactoryWithMemoryManager>;
86 
87     BackendPtrMap       m_Backends;
88     WorkloadFactoryMap  m_WorkloadFactories;
89 
90     std::unique_ptr<OptimizedNetwork> m_OptimizedNetwork;
91     WorkloadQueue m_InputQueue;
92     WorkloadQueue m_WorkloadQueue;
93     WorkloadQueue m_OutputQueue;
94     std::shared_ptr<Profiler> m_Profiler;
95 
96     mutable std::mutex m_WorkingMemMutex;
97 
98     bool m_IsWorkingMemAllocated=false;
99     bool m_IsImportEnabled=false;
100     bool m_IsExportEnabled=false;
101 
102     TensorHandleFactoryRegistry m_TensorHandleFactoryRegistry;
103 
104     profiling::ProfilingService&  m_ProfilingService;
105 };
106 
107 }
108