• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2019 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
#include "ActivateTimelineReportingCommandHandler.hpp"
#include "BufferManager.hpp"
#include "CommandHandler.hpp"
#include "ConnectionAcknowledgedCommandHandler.hpp"
#include "CounterDirectory.hpp"
#include "CounterIdMap.hpp"
#include "DeactivateTimelineReportingCommandHandler.hpp"
#include "ICounterRegistry.hpp"
#include "ICounterValues.hpp"
#include <armnn/profiling/ILocalPacketHandler.hpp>
#include "IProfilingService.hpp"
#include "IReportStructure.hpp"
#include "PeriodicCounterCapture.hpp"
#include "PeriodicCounterSelectionCommandHandler.hpp"
#include "PerJobCounterSelectionCommandHandler.hpp"
#include "ProfilingConnectionFactory.hpp"
#include "ProfilingGuidGenerator.hpp"
#include "ProfilingStateMachine.hpp"
#include "RequestCounterDirectoryCommandHandler.hpp"
#include "SendCounterPacket.hpp"
#include "SendThread.hpp"
#include "SendTimelinePacket.hpp"
#include "TimelinePacketWriterFactory.hpp"
#include "INotifyBackends.hpp"
#include <armnn/backends/profiling/IBackendProfilingContext.hpp>

#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <list>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
35 
36 namespace armnn
37 {
38 
39 namespace profiling
40 {
// UIDs of the counters that ArmNN itself registers.
// Namespace-scope constexpr variables are implicitly const and therefore have internal linkage,
// so every translation unit including this header gets its own copy (same as "static const").
constexpr uint16_t NETWORK_LOADS         = 0;
constexpr uint16_t NETWORK_UNLOADS       = 1;
constexpr uint16_t REGISTERED_BACKENDS   = 2;
constexpr uint16_t UNREGISTERED_BACKENDS = 3;
constexpr uint16_t INFERENCES_RUN        = 4;
// Highest ArmNN-owned counter UID; keep in sync with the last entry above.
constexpr uint16_t MAX_ARMNN_COUNTER     = INFERENCES_RUN;
48 
49 class ProfilingService : public IReadWriteCounterValues, public IProfilingService, public INotifyBackends
50 {
51 public:
52     using ExternalProfilingOptions = IRuntime::CreationOptions::ExternalProfilingOptions;
53     using IProfilingConnectionFactoryPtr = std::unique_ptr<IProfilingConnectionFactory>;
54     using IProfilingConnectionPtr = std::unique_ptr<IProfilingConnection>;
55     using CounterIndices = std::vector<std::atomic<uint32_t>*>;
56     using CounterValues = std::list<std::atomic<uint32_t>>;
57     using BackendProfilingContext = std::unordered_map<BackendId,
58                                                        std::shared_ptr<armnn::profiling::IBackendProfilingContext>>;
59 
ProfilingService(Optional<IReportStructure &> reportStructure=EmptyOptional ())60     ProfilingService(Optional<IReportStructure&> reportStructure = EmptyOptional())
61         : m_Options()
62         , m_TimelineReporting(false)
63         , m_CounterDirectory()
64         , m_ProfilingConnectionFactory(new ProfilingConnectionFactory())
65         , m_ProfilingConnection()
66         , m_StateMachine()
67         , m_CounterIndex()
68         , m_CounterValues()
69         , m_CommandHandlerRegistry()
70         , m_PacketVersionResolver()
71         , m_CommandHandler(1000,
72                            false,
73                            m_CommandHandlerRegistry,
74                            m_PacketVersionResolver)
75         , m_BufferManager()
76         , m_SendCounterPacket(m_BufferManager)
77         , m_SendThread(m_StateMachine, m_BufferManager, m_SendCounterPacket)
78         , m_SendTimelinePacket(m_BufferManager)
79         , m_PeriodicCounterCapture(m_Holder, m_SendCounterPacket, *this, m_CounterIdMap, m_BackendProfilingContexts)
80         , m_ConnectionAcknowledgedCommandHandler(0,
81                                                  1,
82                                                  m_PacketVersionResolver.ResolvePacketVersion(0, 1).GetEncodedValue(),
83                                                  m_CounterDirectory,
84                                                  m_SendCounterPacket,
85                                                  m_SendTimelinePacket,
86                                                  m_StateMachine,
87                                                  *this,
88                                                  m_BackendProfilingContexts)
89         , m_RequestCounterDirectoryCommandHandler(0,
90                                                   3,
91                                                   m_PacketVersionResolver.ResolvePacketVersion(0, 3).GetEncodedValue(),
92                                                   m_CounterDirectory,
93                                                   m_SendCounterPacket,
94                                                   m_SendTimelinePacket,
95                                                   m_StateMachine)
96         , m_PeriodicCounterSelectionCommandHandler(0,
97                                                    4,
98                                                    m_PacketVersionResolver.ResolvePacketVersion(0, 4).GetEncodedValue(),
99                                                    m_BackendProfilingContexts,
100                                                    m_CounterIdMap,
101                                                    m_Holder,
102                                                    MAX_ARMNN_COUNTER,
103                                                    m_PeriodicCounterCapture,
104                                                    *this,
105                                                    m_SendCounterPacket,
106                                                    m_StateMachine)
107         , m_PerJobCounterSelectionCommandHandler(0,
108                                                  5,
109                                                  m_PacketVersionResolver.ResolvePacketVersion(0, 5).GetEncodedValue(),
110                                                  m_StateMachine)
111         , m_ActivateTimelineReportingCommandHandler(0,
112                                                     6,
113                                                     m_PacketVersionResolver.ResolvePacketVersion(0, 6)
114                                                                            .GetEncodedValue(),
115                                                     m_SendTimelinePacket,
116                                                     m_StateMachine,
117                                                     reportStructure,
118                                                     m_TimelineReporting,
119                                                     *this)
120         , m_DeactivateTimelineReportingCommandHandler(0,
121                                                       7,
122                                                       m_PacketVersionResolver.ResolvePacketVersion(0, 7)
123                                                                              .GetEncodedValue(),
124                                                       m_TimelineReporting,
125                                                       m_StateMachine,
126                                                       *this)
127         , m_TimelinePacketWriterFactory(m_BufferManager)
128         , m_MaxGlobalCounterId(armnn::profiling::INFERENCES_RUN)
129         , m_ServiceActive(false)
130     {
131         // Register the "Connection Acknowledged" command handler
132         m_CommandHandlerRegistry.RegisterFunctor(&m_ConnectionAcknowledgedCommandHandler);
133 
134         // Register the "Request Counter Directory" command handler
135         m_CommandHandlerRegistry.RegisterFunctor(&m_RequestCounterDirectoryCommandHandler);
136 
137         // Register the "Periodic Counter Selection" command handler
138         m_CommandHandlerRegistry.RegisterFunctor(&m_PeriodicCounterSelectionCommandHandler);
139 
140         // Register the "Per-Job Counter Selection" command handler
141         m_CommandHandlerRegistry.RegisterFunctor(&m_PerJobCounterSelectionCommandHandler);
142 
143         m_CommandHandlerRegistry.RegisterFunctor(&m_ActivateTimelineReportingCommandHandler);
144 
145         m_CommandHandlerRegistry.RegisterFunctor(&m_DeactivateTimelineReportingCommandHandler);
146     }
147 
148     ~ProfilingService();
149 
150     // Resets the profiling options, optionally clears the profiling service entirely
151     void ResetExternalProfilingOptions(const ExternalProfilingOptions& options, bool resetProfilingService = false);
152     ProfilingState ConfigureProfilingService(const ExternalProfilingOptions& options,
153                                              bool resetProfilingService = false);
154 
155 
156     // Updates the profiling service, making it transition to a new state if necessary
157     void Update();
158 
159     // Disconnects the profiling service from the external server
160     void Disconnect();
161 
162     // Store a profiling context returned from a backend that support profiling.
163     void AddBackendProfilingContext(const BackendId backendId,
164         std::shared_ptr<armnn::profiling::IBackendProfilingContext> profilingContext);
165 
166     // Enable the recording of timeline events and entities
167     void NotifyBackendsForTimelineReporting() override;
168 
169     const ICounterDirectory& GetCounterDirectory() const;
170     ICounterRegistry& GetCounterRegistry();
171     ProfilingState GetCurrentState() const;
172     bool IsCounterRegistered(uint16_t counterUid) const override;
173     uint32_t GetAbsoluteCounterValue(uint16_t counterUid) const override;
174     uint32_t GetDeltaCounterValue(uint16_t counterUid) override;
175     uint16_t GetCounterCount() const override;
176     // counter global/backend mapping functions
177     const ICounterMappings& GetCounterMappings() const override;
178     IRegisterCounterMapping& GetCounterMappingRegistry();
179 
180     // Getters for the profiling service state
181     bool IsProfilingEnabled() const override;
182 
183     CaptureData GetCaptureData() override;
184     void SetCaptureData(uint32_t capturePeriod,
185                         const std::vector<uint16_t>& counterIds,
186                         const std::set<BackendId>& activeBackends);
187 
188     // Setters for the profiling service state
189     void SetCounterValue(uint16_t counterUid, uint32_t value) override;
190     uint32_t AddCounterValue(uint16_t counterUid, uint32_t value) override;
191     uint32_t SubtractCounterValue(uint16_t counterUid, uint32_t value) override;
192     uint32_t IncrementCounterValue(uint16_t counterUid) override;
193 
194     // IProfilingGuidGenerator functions
195     /// Return the next random Guid in the sequence
196     ProfilingDynamicGuid NextGuid() override;
197     /// Create a ProfilingStaticGuid based on a hash of the string
198     ProfilingStaticGuid GenerateStaticId(const std::string& str) override;
199 
200 
201     std::unique_ptr<ISendTimelinePacket> GetSendTimelinePacket() const override;
202 
GetSendCounterPacket()203     ISendCounterPacket& GetSendCounterPacket() override
204     {
205         return m_SendCounterPacket;
206     }
207 
208     static ProfilingDynamicGuid GetNextGuid();
209 
210     static ProfilingStaticGuid GetStaticId(const std::string& str);
211 
212     void ResetGuidGenerator();
213 
IsTimelineReportingEnabled()214     bool IsTimelineReportingEnabled()
215     {
216         return m_TimelineReporting;
217     }
218 
219     void AddLocalPacketHandler(ILocalPacketHandlerSharedPtr localPacketHandler);
220 
221     void NotifyProfilingServiceActive() override; // IProfilingServiceStatus
222     void WaitForProfilingServiceActivation(unsigned int timeout) override; // IProfilingServiceStatus
223 
224 private:
225     //Copy/move constructors/destructors and copy/move assignment operators are deleted
226     ProfilingService(const ProfilingService&) = delete;
227     ProfilingService(ProfilingService&&) = delete;
228     ProfilingService& operator=(const ProfilingService&) = delete;
229     ProfilingService& operator=(ProfilingService&&) = delete;
230 
231     // Initialization/reset functions
232     void Initialize();
233     void InitializeCounterValue(uint16_t counterUid);
234     void Reset();
235     void Stop();
236 
237     // Helper function
238     void CheckCounterUid(uint16_t counterUid) const;
239 
240     // Profiling service components
241     ExternalProfilingOptions           m_Options;
242     std::atomic<bool>                  m_TimelineReporting;
243     CounterDirectory                   m_CounterDirectory;
244     CounterIdMap                       m_CounterIdMap;
245     IProfilingConnectionFactoryPtr     m_ProfilingConnectionFactory;
246     IProfilingConnectionPtr            m_ProfilingConnection;
247     ProfilingStateMachine              m_StateMachine;
248     CounterIndices                     m_CounterIndex;
249     CounterValues                      m_CounterValues;
250     arm::pipe::CommandHandlerRegistry  m_CommandHandlerRegistry;
251     arm::pipe::PacketVersionResolver   m_PacketVersionResolver;
252     CommandHandler                     m_CommandHandler;
253     BufferManager                      m_BufferManager;
254     SendCounterPacket                  m_SendCounterPacket;
255     SendThread                         m_SendThread;
256     SendTimelinePacket                 m_SendTimelinePacket;
257 
258     Holder m_Holder;
259 
260     PeriodicCounterCapture m_PeriodicCounterCapture;
261 
262     ConnectionAcknowledgedCommandHandler      m_ConnectionAcknowledgedCommandHandler;
263     RequestCounterDirectoryCommandHandler     m_RequestCounterDirectoryCommandHandler;
264     PeriodicCounterSelectionCommandHandler    m_PeriodicCounterSelectionCommandHandler;
265     PerJobCounterSelectionCommandHandler      m_PerJobCounterSelectionCommandHandler;
266     ActivateTimelineReportingCommandHandler   m_ActivateTimelineReportingCommandHandler;
267     DeactivateTimelineReportingCommandHandler m_DeactivateTimelineReportingCommandHandler;
268 
269     TimelinePacketWriterFactory m_TimelinePacketWriterFactory;
270     BackendProfilingContext     m_BackendProfilingContexts;
271     uint16_t                    m_MaxGlobalCounterId;
272 
273     static ProfilingGuidGenerator m_GuidGenerator;
274 
275     // Signalling to let external actors know when service is active or not
276     std::mutex m_ServiceActiveMutex;
277     std::condition_variable m_ServiceActiveConditionVariable;
278     bool m_ServiceActive;
279 
280 protected:
281 
282     // Protected methods for testing
SwapProfilingConnectionFactory(ProfilingService & instance,IProfilingConnectionFactory * other,IProfilingConnectionFactory * & backup)283     void SwapProfilingConnectionFactory(ProfilingService& instance,
284                                         IProfilingConnectionFactory* other,
285                                         IProfilingConnectionFactory*& backup)
286     {
287         ARMNN_ASSERT(instance.m_ProfilingConnectionFactory);
288         ARMNN_ASSERT(other);
289 
290         backup = instance.m_ProfilingConnectionFactory.release();
291         instance.m_ProfilingConnectionFactory.reset(other);
292     }
GetProfilingConnection(ProfilingService & instance)293     IProfilingConnection* GetProfilingConnection(ProfilingService& instance)
294     {
295         return instance.m_ProfilingConnection.get();
296     }
TransitionToState(ProfilingService & instance,ProfilingState newState)297     void TransitionToState(ProfilingService& instance, ProfilingState newState)
298     {
299         instance.m_StateMachine.TransitionToState(newState);
300     }
WaitForPacketSent(ProfilingService & instance,uint32_t timeout=1000)301     bool WaitForPacketSent(ProfilingService& instance, uint32_t timeout = 1000)
302     {
303         return instance.m_SendThread.WaitForPacketSent(timeout);
304     }
305 
GetBufferManager(ProfilingService & instance)306     BufferManager& GetBufferManager(ProfilingService& instance)
307     {
308         return instance.m_BufferManager;
309     }
310 };
311 
312 } // namespace profiling
313 
314 } // namespace armnn
315