1 // 2 // Copyright © 2019 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include "ActivateTimelineReportingCommandHandler.hpp" 9 #include "BufferManager.hpp" 10 #include "CommandHandler.hpp" 11 #include "ConnectionAcknowledgedCommandHandler.hpp" 12 #include "CounterDirectory.hpp" 13 #include "CounterIdMap.hpp" 14 #include "DeactivateTimelineReportingCommandHandler.hpp" 15 #include "ICounterRegistry.hpp" 16 #include "ICounterValues.hpp" 17 #include <armnn/profiling/ILocalPacketHandler.hpp> 18 #include "IProfilingService.hpp" 19 #include "IReportStructure.hpp" 20 #include "PeriodicCounterCapture.hpp" 21 #include "PeriodicCounterSelectionCommandHandler.hpp" 22 #include "PerJobCounterSelectionCommandHandler.hpp" 23 #include "ProfilingConnectionFactory.hpp" 24 #include "ProfilingGuidGenerator.hpp" 25 #include "ProfilingStateMachine.hpp" 26 #include "RequestCounterDirectoryCommandHandler.hpp" 27 #include "SendCounterPacket.hpp" 28 #include "SendThread.hpp" 29 #include "SendTimelinePacket.hpp" 30 #include "TimelinePacketWriterFactory.hpp" 31 #include "INotifyBackends.hpp" 32 #include <armnn/backends/profiling/IBackendProfilingContext.hpp> 33 34 #include <list> 35 36 namespace armnn 37 { 38 39 namespace profiling 40 { 41 // Static constants describing ArmNN's counter UID's 42 static const uint16_t NETWORK_LOADS = 0; 43 static const uint16_t NETWORK_UNLOADS = 1; 44 static const uint16_t REGISTERED_BACKENDS = 2; 45 static const uint16_t UNREGISTERED_BACKENDS = 3; 46 static const uint16_t INFERENCES_RUN = 4; 47 static const uint16_t MAX_ARMNN_COUNTER = INFERENCES_RUN; 48 49 class ProfilingService : public IReadWriteCounterValues, public IProfilingService, public INotifyBackends 50 { 51 public: 52 using ExternalProfilingOptions = IRuntime::CreationOptions::ExternalProfilingOptions; 53 using IProfilingConnectionFactoryPtr = std::unique_ptr<IProfilingConnectionFactory>; 54 using IProfilingConnectionPtr = std::unique_ptr<IProfilingConnection>; 55 using CounterIndices = std::vector<std::atomic<uint32_t>*>; 56 using CounterValues = std::list<std::atomic<uint32_t>>; 57 using BackendProfilingContext = std::unordered_map<BackendId, 58 std::shared_ptr<armnn::profiling::IBackendProfilingContext>>; 59 ProfilingService(Optional<IReportStructure &> reportStructure=EmptyOptional ())60 ProfilingService(Optional<IReportStructure&> reportStructure = EmptyOptional()) 61 : m_Options() 62 , m_TimelineReporting(false) 63 , m_CounterDirectory() 64 , m_ProfilingConnectionFactory(new ProfilingConnectionFactory()) 65 , m_ProfilingConnection() 66 , m_StateMachine() 67 , m_CounterIndex() 68 , m_CounterValues() 69 , m_CommandHandlerRegistry() 70 , m_PacketVersionResolver() 71 , m_CommandHandler(1000, 72 false, 73 m_CommandHandlerRegistry, 74 m_PacketVersionResolver) 75 , m_BufferManager() 76 , m_SendCounterPacket(m_BufferManager) 77 , m_SendThread(m_StateMachine, m_BufferManager, m_SendCounterPacket) 78 , m_SendTimelinePacket(m_BufferManager) 79 , m_PeriodicCounterCapture(m_Holder, m_SendCounterPacket, *this, m_CounterIdMap, m_BackendProfilingContexts) 80 , m_ConnectionAcknowledgedCommandHandler(0, 81 1, 82 m_PacketVersionResolver.ResolvePacketVersion(0, 1).GetEncodedValue(), 83 m_CounterDirectory, 84 m_SendCounterPacket, 85 m_SendTimelinePacket, 86 m_StateMachine, 87 *this, 88 m_BackendProfilingContexts) 89 , m_RequestCounterDirectoryCommandHandler(0, 90 3, 91 m_PacketVersionResolver.ResolvePacketVersion(0, 3).GetEncodedValue(), 92 m_CounterDirectory, 93 m_SendCounterPacket, 94 m_SendTimelinePacket, 95 m_StateMachine) 96 , m_PeriodicCounterSelectionCommandHandler(0, 97 4, 98 m_PacketVersionResolver.ResolvePacketVersion(0, 4).GetEncodedValue(), 99 m_BackendProfilingContexts, 100 m_CounterIdMap, 101 m_Holder, 102 MAX_ARMNN_COUNTER, 103 m_PeriodicCounterCapture, 104 *this, 105 m_SendCounterPacket, 106 m_StateMachine) 107 , m_PerJobCounterSelectionCommandHandler(0, 108 5, 109 m_PacketVersionResolver.ResolvePacketVersion(0, 5).GetEncodedValue(), 110 m_StateMachine) 111 , m_ActivateTimelineReportingCommandHandler(0, 112 6, 113 m_PacketVersionResolver.ResolvePacketVersion(0, 6) 114 .GetEncodedValue(), 115 m_SendTimelinePacket, 116 m_StateMachine, 117 reportStructure, 118 m_TimelineReporting, 119 *this) 120 , m_DeactivateTimelineReportingCommandHandler(0, 121 7, 122 m_PacketVersionResolver.ResolvePacketVersion(0, 7) 123 .GetEncodedValue(), 124 m_TimelineReporting, 125 m_StateMachine, 126 *this) 127 , m_TimelinePacketWriterFactory(m_BufferManager) 128 , m_MaxGlobalCounterId(armnn::profiling::INFERENCES_RUN) 129 , m_ServiceActive(false) 130 { 131 // Register the "Connection Acknowledged" command handler 132 m_CommandHandlerRegistry.RegisterFunctor(&m_ConnectionAcknowledgedCommandHandler); 133 134 // Register the "Request Counter Directory" command handler 135 m_CommandHandlerRegistry.RegisterFunctor(&m_RequestCounterDirectoryCommandHandler); 136 137 // Register the "Periodic Counter Selection" command handler 138 m_CommandHandlerRegistry.RegisterFunctor(&m_PeriodicCounterSelectionCommandHandler); 139 140 // Register the "Per-Job Counter Selection" command handler 141 m_CommandHandlerRegistry.RegisterFunctor(&m_PerJobCounterSelectionCommandHandler); 142 143 m_CommandHandlerRegistry.RegisterFunctor(&m_ActivateTimelineReportingCommandHandler); 144 145 m_CommandHandlerRegistry.RegisterFunctor(&m_DeactivateTimelineReportingCommandHandler); 146 } 147 148 ~ProfilingService(); 149 150 // Resets the profiling options, optionally clears the profiling service entirely 151 void ResetExternalProfilingOptions(const ExternalProfilingOptions& options, bool resetProfilingService = false); 152 ProfilingState ConfigureProfilingService(const ExternalProfilingOptions& options, 153 bool resetProfilingService = false); 154 155 156 // Updates the profiling service, making it transition to a new state if necessary 157 void Update(); 158 159 // Disconnects the profiling service from the external server 160 void Disconnect(); 161 162 // Store a profiling context returned from a backend that support profiling. 163 void AddBackendProfilingContext(const BackendId backendId, 164 std::shared_ptr<armnn::profiling::IBackendProfilingContext> profilingContext); 165 166 // Enable the recording of timeline events and entities 167 void NotifyBackendsForTimelineReporting() override; 168 169 const ICounterDirectory& GetCounterDirectory() const; 170 ICounterRegistry& GetCounterRegistry(); 171 ProfilingState GetCurrentState() const; 172 bool IsCounterRegistered(uint16_t counterUid) const override; 173 uint32_t GetAbsoluteCounterValue(uint16_t counterUid) const override; 174 uint32_t GetDeltaCounterValue(uint16_t counterUid) override; 175 uint16_t GetCounterCount() const override; 176 // counter global/backend mapping functions 177 const ICounterMappings& GetCounterMappings() const override; 178 IRegisterCounterMapping& GetCounterMappingRegistry(); 179 180 // Getters for the profiling service state 181 bool IsProfilingEnabled() const override; 182 183 CaptureData GetCaptureData() override; 184 void SetCaptureData(uint32_t capturePeriod, 185 const std::vector<uint16_t>& counterIds, 186 const std::set<BackendId>& activeBackends); 187 188 // Setters for the profiling service state 189 void SetCounterValue(uint16_t counterUid, uint32_t value) override; 190 uint32_t AddCounterValue(uint16_t counterUid, uint32_t value) override; 191 uint32_t SubtractCounterValue(uint16_t counterUid, uint32_t value) override; 192 uint32_t IncrementCounterValue(uint16_t counterUid) override; 193 194 // IProfilingGuidGenerator functions 195 /// Return the next random Guid in the sequence 196 ProfilingDynamicGuid NextGuid() override; 197 /// Create a ProfilingStaticGuid based on a hash of the string 198 ProfilingStaticGuid GenerateStaticId(const std::string& str) override; 199 200 201 std::unique_ptr<ISendTimelinePacket> GetSendTimelinePacket() const override; 202 GetSendCounterPacket()203 ISendCounterPacket& GetSendCounterPacket() override 204 { 205 return m_SendCounterPacket; 206 } 207 208 static ProfilingDynamicGuid GetNextGuid(); 209 210 static ProfilingStaticGuid GetStaticId(const std::string& str); 211 212 void ResetGuidGenerator(); 213 IsTimelineReportingEnabled()214 bool IsTimelineReportingEnabled() 215 { 216 return m_TimelineReporting; 217 } 218 219 void AddLocalPacketHandler(ILocalPacketHandlerSharedPtr localPacketHandler); 220 221 void NotifyProfilingServiceActive() override; // IProfilingServiceStatus 222 void WaitForProfilingServiceActivation(unsigned int timeout) override; // IProfilingServiceStatus 223 224 private: 225 //Copy/move constructors/destructors and copy/move assignment operators are deleted 226 ProfilingService(const ProfilingService&) = delete; 227 ProfilingService(ProfilingService&&) = delete; 228 ProfilingService& operator=(const ProfilingService&) = delete; 229 ProfilingService& operator=(ProfilingService&&) = delete; 230 231 // Initialization/reset functions 232 void Initialize(); 233 void InitializeCounterValue(uint16_t counterUid); 234 void Reset(); 235 void Stop(); 236 237 // Helper function 238 void CheckCounterUid(uint16_t counterUid) const; 239 240 // Profiling service components 241 ExternalProfilingOptions m_Options; 242 std::atomic<bool> m_TimelineReporting; 243 CounterDirectory m_CounterDirectory; 244 CounterIdMap m_CounterIdMap; 245 IProfilingConnectionFactoryPtr m_ProfilingConnectionFactory; 246 IProfilingConnectionPtr m_ProfilingConnection; 247 ProfilingStateMachine m_StateMachine; 248 CounterIndices m_CounterIndex; 249 CounterValues m_CounterValues; 250 arm::pipe::CommandHandlerRegistry m_CommandHandlerRegistry; 251 arm::pipe::PacketVersionResolver m_PacketVersionResolver; 252 CommandHandler m_CommandHandler; 253 BufferManager m_BufferManager; 254 SendCounterPacket m_SendCounterPacket; 255 SendThread m_SendThread; 256 SendTimelinePacket m_SendTimelinePacket; 257 258 Holder m_Holder; 259 260 PeriodicCounterCapture m_PeriodicCounterCapture; 261 262 ConnectionAcknowledgedCommandHandler m_ConnectionAcknowledgedCommandHandler; 263 RequestCounterDirectoryCommandHandler m_RequestCounterDirectoryCommandHandler; 264 PeriodicCounterSelectionCommandHandler m_PeriodicCounterSelectionCommandHandler; 265 PerJobCounterSelectionCommandHandler m_PerJobCounterSelectionCommandHandler; 266 ActivateTimelineReportingCommandHandler m_ActivateTimelineReportingCommandHandler; 267 DeactivateTimelineReportingCommandHandler m_DeactivateTimelineReportingCommandHandler; 268 269 TimelinePacketWriterFactory m_TimelinePacketWriterFactory; 270 BackendProfilingContext m_BackendProfilingContexts; 271 uint16_t m_MaxGlobalCounterId; 272 273 static ProfilingGuidGenerator m_GuidGenerator; 274 275 // Signalling to let external actors know when service is active or not 276 std::mutex m_ServiceActiveMutex; 277 std::condition_variable m_ServiceActiveConditionVariable; 278 bool m_ServiceActive; 279 280 protected: 281 282 // Protected methods for testing SwapProfilingConnectionFactory(ProfilingService & instance,IProfilingConnectionFactory * other,IProfilingConnectionFactory * & backup)283 void SwapProfilingConnectionFactory(ProfilingService& instance, 284 IProfilingConnectionFactory* other, 285 IProfilingConnectionFactory*& backup) 286 { 287 ARMNN_ASSERT(instance.m_ProfilingConnectionFactory); 288 ARMNN_ASSERT(other); 289 290 backup = instance.m_ProfilingConnectionFactory.release(); 291 instance.m_ProfilingConnectionFactory.reset(other); 292 } GetProfilingConnection(ProfilingService & instance)293 IProfilingConnection* GetProfilingConnection(ProfilingService& instance) 294 { 295 return instance.m_ProfilingConnection.get(); 296 } TransitionToState(ProfilingService & instance,ProfilingState newState)297 void TransitionToState(ProfilingService& instance, ProfilingState newState) 298 { 299 instance.m_StateMachine.TransitionToState(newState); 300 } WaitForPacketSent(ProfilingService & instance,uint32_t timeout=1000)301 bool WaitForPacketSent(ProfilingService& instance, uint32_t timeout = 1000) 302 { 303 return instance.m_SendThread.WaitForPacketSent(timeout); 304 } 305 GetBufferManager(ProfilingService & instance)306 BufferManager& GetBufferManager(ProfilingService& instance) 307 { 308 return instance.m_BufferManager; 309 } 310 }; 311 312 } // namespace profiling 313 314 } // namespace armnn 315