1 //
2 // Copyright © 2019 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "ProfilingService.hpp"
7
8 #include <armnn/BackendId.hpp>
9 #include <armnn/Logging.hpp>
10 #include <armnn/utility/NumericCast.hpp>
11
12 #include <common/include/SocketConnectionException.hpp>
13
14 #include <fmt/format.h>
15
16 namespace armnn
17 {
18
19 namespace profiling
20 {
21
// Out-of-class definition of the static GUID generator member. GetNextGuid(), GetStaticId() and
// ResetGuidGenerator() below all operate on this single instance.
22 ProfilingGuidGenerator ProfilingService::m_GuidGenerator;
23
GetNextGuid()24 ProfilingDynamicGuid ProfilingService::GetNextGuid()
25 {
26 return m_GuidGenerator.NextGuid();
27 }
28
GetStaticId(const std::string & str)29 ProfilingStaticGuid ProfilingService::GetStaticId(const std::string& str)
30 {
31 return m_GuidGenerator.GenerateStaticId(str);
32 }
33
ResetGuidGenerator()34 void ProfilingService::ResetGuidGenerator()
35 {
36 m_GuidGenerator.Reset();
37 }
38
ResetExternalProfilingOptions(const ExternalProfilingOptions & options,bool resetProfilingService)39 void ProfilingService::ResetExternalProfilingOptions(const ExternalProfilingOptions& options,
40 bool resetProfilingService)
41 {
42 // Update the profiling options
43 m_Options = options;
44 m_TimelineReporting = options.m_TimelineEnabled;
45 m_ConnectionAcknowledgedCommandHandler.setTimelineEnabled(options.m_TimelineEnabled);
46
47 // Check if the profiling service needs to be reset
48 if (resetProfilingService)
49 {
50 // Reset the profiling service
51 Reset();
52 }
53 }
54
IsProfilingEnabled() const55 bool ProfilingService::IsProfilingEnabled() const
56 {
57 return m_Options.m_EnableProfiling;
58 }
59
ConfigureProfilingService(const ExternalProfilingOptions & options,bool resetProfilingService)60 ProfilingState ProfilingService::ConfigureProfilingService(
61 const ExternalProfilingOptions& options,
62 bool resetProfilingService)
63 {
64 ResetExternalProfilingOptions(options, resetProfilingService);
65 ProfilingState currentState = m_StateMachine.GetCurrentState();
66 if (options.m_EnableProfiling)
67 {
68 switch (currentState)
69 {
70 case ProfilingState::Uninitialised:
71 Update(); // should transition to NotConnected
72 Update(); // will either stay in NotConnected because there is no server
73 // or will enter WaitingForAck.
74 currentState = m_StateMachine.GetCurrentState();
75 if (currentState == ProfilingState::WaitingForAck)
76 {
77 Update(); // poke it again to send out the metadata packet
78 }
79 currentState = m_StateMachine.GetCurrentState();
80 return currentState;
81 case ProfilingState::NotConnected:
82 Update(); // will either stay in NotConnected because there is no server
83 // or will enter WaitingForAck
84 currentState = m_StateMachine.GetCurrentState();
85 if (currentState == ProfilingState::WaitingForAck)
86 {
87 Update(); // poke it again to send out the metadata packet
88 }
89 currentState = m_StateMachine.GetCurrentState();
90 return currentState;
91 default:
92 return currentState;
93 }
94 }
95 else
96 {
97 // Make sure profiling is shutdown
98 switch (currentState)
99 {
100 case ProfilingState::Uninitialised:
101 case ProfilingState::NotConnected:
102 return currentState;
103 default:
104 Stop();
105 return m_StateMachine.GetCurrentState();
106 }
107 }
108 }
109
Update()110 void ProfilingService::Update()
111 {
112 if (!m_Options.m_EnableProfiling)
113 {
114 // Don't run if profiling is disabled
115 return;
116 }
117
118 ProfilingState currentState = m_StateMachine.GetCurrentState();
119 switch (currentState)
120 {
121 case ProfilingState::Uninitialised:
122
123 // Initialize the profiling service
124 Initialize();
125
126 // Move to the next state
127 m_StateMachine.TransitionToState(ProfilingState::NotConnected);
128 break;
129 case ProfilingState::NotConnected:
130 // Stop the command thread (if running)
131 m_CommandHandler.Stop();
132
133 // Stop the send thread (if running)
134 m_SendThread.Stop(false);
135
136 // Stop the periodic counter capture thread (if running)
137 m_PeriodicCounterCapture.Stop();
138
139 // Reset any existing profiling connection
140 m_ProfilingConnection.reset();
141
142 try
143 {
144 // Setup the profiling connection
145 ARMNN_ASSERT(m_ProfilingConnectionFactory);
146 m_ProfilingConnection = m_ProfilingConnectionFactory->GetProfilingConnection(m_Options);
147 }
148 catch (const Exception& e)
149 {
150 ARMNN_LOG(warning) << "An error has occurred when creating the profiling connection: "
151 << e.what();
152 }
153 catch (const arm::pipe::SocketConnectionException& e)
154 {
155 ARMNN_LOG(warning) << "An error has occurred when creating the profiling connection ["
156 << e.what() << "] on socket [" << e.GetSocketFd() << "].";
157 }
158
159 // Move to the next state
160 m_StateMachine.TransitionToState(m_ProfilingConnection
161 ? ProfilingState::WaitingForAck // Profiling connection obtained, wait for ack
162 : ProfilingState::NotConnected); // Profiling connection failed, stay in the
163 // "NotConnected" state
164 break;
165 case ProfilingState::WaitingForAck:
166 ARMNN_ASSERT(m_ProfilingConnection);
167
168 // Start the command thread
169 m_CommandHandler.Start(*m_ProfilingConnection);
170
171 // Start the send thread, while in "WaitingForAck" state it'll send out a "Stream MetaData" packet waiting for
172 // a valid "Connection Acknowledged" packet confirming the connection
173 m_SendThread.Start(*m_ProfilingConnection);
174
175 // The connection acknowledged command handler will automatically transition the state to "Active" once a
176 // valid "Connection Acknowledged" packet has been received
177
178 break;
179 case ProfilingState::Active:
180
181 // The period counter capture thread is started by the Periodic Counter Selection command handler upon
182 // request by an external profiling service
183
184 break;
185 default:
186 throw RuntimeException(fmt::format("Unknown profiling service state: {}",
187 static_cast<int>(currentState)));
188 }
189 }
190
Disconnect()191 void ProfilingService::Disconnect()
192 {
193 ProfilingState currentState = m_StateMachine.GetCurrentState();
194 switch (currentState)
195 {
196 case ProfilingState::Uninitialised:
197 case ProfilingState::NotConnected:
198 case ProfilingState::WaitingForAck:
199 return; // NOP
200 case ProfilingState::Active:
201 // Stop the command thread (if running)
202 Stop();
203
204 break;
205 default:
206 throw RuntimeException(fmt::format("Unknown profiling service state: {}",
207 static_cast<int>(currentState)));
208 }
209 }
210
211 // Store a profiling context returned from a backend that support profiling, and register its counters
AddBackendProfilingContext(const BackendId backendId,std::shared_ptr<armnn::profiling::IBackendProfilingContext> profilingContext)212 void ProfilingService::AddBackendProfilingContext(const BackendId backendId,
213 std::shared_ptr<armnn::profiling::IBackendProfilingContext> profilingContext)
214 {
215 ARMNN_ASSERT(profilingContext != nullptr);
216 // Register the backend counters
217 m_MaxGlobalCounterId = profilingContext->RegisterCounters(m_MaxGlobalCounterId);
218 m_BackendProfilingContexts.emplace(backendId, std::move(profilingContext));
219 }
GetCounterDirectory() const220 const ICounterDirectory& ProfilingService::GetCounterDirectory() const
221 {
222 return m_CounterDirectory;
223 }
224
GetCounterRegistry()225 ICounterRegistry& ProfilingService::GetCounterRegistry()
226 {
227 return m_CounterDirectory;
228 }
229
GetCurrentState() const230 ProfilingState ProfilingService::GetCurrentState() const
231 {
232 return m_StateMachine.GetCurrentState();
233 }
234
GetCounterCount() const235 uint16_t ProfilingService::GetCounterCount() const
236 {
237 return m_CounterDirectory.GetCounterCount();
238 }
239
IsCounterRegistered(uint16_t counterUid) const240 bool ProfilingService::IsCounterRegistered(uint16_t counterUid) const
241 {
242 return m_CounterDirectory.IsCounterRegistered(counterUid);
243 }
244
GetAbsoluteCounterValue(uint16_t counterUid) const245 uint32_t ProfilingService::GetAbsoluteCounterValue(uint16_t counterUid) const
246 {
247 CheckCounterUid(counterUid);
248 std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
249 ARMNN_ASSERT(counterValuePtr);
250 return counterValuePtr->load(std::memory_order::memory_order_relaxed);
251 }
252
GetDeltaCounterValue(uint16_t counterUid)253 uint32_t ProfilingService::GetDeltaCounterValue(uint16_t counterUid)
254 {
255 CheckCounterUid(counterUid);
256 std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
257 ARMNN_ASSERT(counterValuePtr);
258 const uint32_t counterValue = counterValuePtr->load(std::memory_order::memory_order_relaxed);
259 SubtractCounterValue(counterUid, counterValue);
260 return counterValue;
261 }
262
GetCounterMappings() const263 const ICounterMappings& ProfilingService::GetCounterMappings() const
264 {
265 return m_CounterIdMap;
266 }
267
GetCounterMappingRegistry()268 IRegisterCounterMapping& ProfilingService::GetCounterMappingRegistry()
269 {
270 return m_CounterIdMap;
271 }
272
GetCaptureData()273 CaptureData ProfilingService::GetCaptureData()
274 {
275 return m_Holder.GetCaptureData();
276 }
277
SetCaptureData(uint32_t capturePeriod,const std::vector<uint16_t> & counterIds,const std::set<BackendId> & activeBackends)278 void ProfilingService::SetCaptureData(uint32_t capturePeriod,
279 const std::vector<uint16_t>& counterIds,
280 const std::set<BackendId>& activeBackends)
281 {
282 m_Holder.SetCaptureData(capturePeriod, counterIds, activeBackends);
283 }
284
SetCounterValue(uint16_t counterUid,uint32_t value)285 void ProfilingService::SetCounterValue(uint16_t counterUid, uint32_t value)
286 {
287 CheckCounterUid(counterUid);
288 std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
289 ARMNN_ASSERT(counterValuePtr);
290 counterValuePtr->store(value, std::memory_order::memory_order_relaxed);
291 }
292
AddCounterValue(uint16_t counterUid,uint32_t value)293 uint32_t ProfilingService::AddCounterValue(uint16_t counterUid, uint32_t value)
294 {
295 CheckCounterUid(counterUid);
296 std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
297 ARMNN_ASSERT(counterValuePtr);
298 return counterValuePtr->fetch_add(value, std::memory_order::memory_order_relaxed);
299 }
300
SubtractCounterValue(uint16_t counterUid,uint32_t value)301 uint32_t ProfilingService::SubtractCounterValue(uint16_t counterUid, uint32_t value)
302 {
303 CheckCounterUid(counterUid);
304 std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
305 ARMNN_ASSERT(counterValuePtr);
306 return counterValuePtr->fetch_sub(value, std::memory_order::memory_order_relaxed);
307 }
308
IncrementCounterValue(uint16_t counterUid)309 uint32_t ProfilingService::IncrementCounterValue(uint16_t counterUid)
310 {
311 CheckCounterUid(counterUid);
312 std::atomic<uint32_t>* counterValuePtr = m_CounterIndex.at(counterUid);
313 ARMNN_ASSERT(counterValuePtr);
314 return counterValuePtr->operator++(std::memory_order::memory_order_relaxed);
315 }
316
NextGuid()317 ProfilingDynamicGuid ProfilingService::NextGuid()
318 {
319 return ProfilingService::GetNextGuid();
320 }
321
GenerateStaticId(const std::string & str)322 ProfilingStaticGuid ProfilingService::GenerateStaticId(const std::string& str)
323 {
324 return ProfilingService::GetStaticId(str);
325 }
326
GetSendTimelinePacket() const327 std::unique_ptr<ISendTimelinePacket> ProfilingService::GetSendTimelinePacket() const
328 {
329 return m_TimelinePacketWriterFactory.GetSendTimelinePacket();
330 }
331
Initialize()332 void ProfilingService::Initialize()
333 {
334 // Register a category for the basic runtime counters
335 if (!m_CounterDirectory.IsCategoryRegistered("ArmNN_Runtime"))
336 {
337 m_CounterDirectory.RegisterCategory("ArmNN_Runtime");
338 }
339
340 // Register a counter for the number of Network loads
341 if (!m_CounterDirectory.IsCounterRegistered("Network loads"))
342 {
343 const Counter* loadedNetworksCounter =
344 m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
345 armnn::profiling::NETWORK_LOADS,
346 "ArmNN_Runtime",
347 0,
348 0,
349 1.f,
350 "Network loads",
351 "The number of networks loaded at runtime",
352 std::string("networks"));
353 ARMNN_ASSERT(loadedNetworksCounter);
354 InitializeCounterValue(loadedNetworksCounter->m_Uid);
355 }
356 // Register a counter for the number of unloaded networks
357 if (!m_CounterDirectory.IsCounterRegistered("Network unloads"))
358 {
359 const Counter* unloadedNetworksCounter =
360 m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
361 armnn::profiling::NETWORK_UNLOADS,
362 "ArmNN_Runtime",
363 0,
364 0,
365 1.f,
366 "Network unloads",
367 "The number of networks unloaded at runtime",
368 std::string("networks"));
369 ARMNN_ASSERT(unloadedNetworksCounter);
370 InitializeCounterValue(unloadedNetworksCounter->m_Uid);
371 }
372 // Register a counter for the number of registered backends
373 if (!m_CounterDirectory.IsCounterRegistered("Backends registered"))
374 {
375 const Counter* registeredBackendsCounter =
376 m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
377 armnn::profiling::REGISTERED_BACKENDS,
378 "ArmNN_Runtime",
379 0,
380 0,
381 1.f,
382 "Backends registered",
383 "The number of registered backends",
384 std::string("backends"));
385 ARMNN_ASSERT(registeredBackendsCounter);
386 InitializeCounterValue(registeredBackendsCounter->m_Uid);
387
388 // Due to backends being registered before the profiling service becomes active,
389 // we need to set the counter to the correct value here
390 SetCounterValue(armnn::profiling::REGISTERED_BACKENDS, static_cast<uint32_t>(BackendRegistryInstance().Size()));
391 }
392 // Register a counter for the number of registered backends
393 if (!m_CounterDirectory.IsCounterRegistered("Backends unregistered"))
394 {
395 const Counter* unregisteredBackendsCounter =
396 m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
397 armnn::profiling::UNREGISTERED_BACKENDS,
398 "ArmNN_Runtime",
399 0,
400 0,
401 1.f,
402 "Backends unregistered",
403 "The number of unregistered backends",
404 std::string("backends"));
405 ARMNN_ASSERT(unregisteredBackendsCounter);
406 InitializeCounterValue(unregisteredBackendsCounter->m_Uid);
407 }
408 // Register a counter for the number of inferences run
409 if (!m_CounterDirectory.IsCounterRegistered("Inferences run"))
410 {
411 const Counter* inferencesRunCounter =
412 m_CounterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
413 armnn::profiling::INFERENCES_RUN,
414 "ArmNN_Runtime",
415 0,
416 0,
417 1.f,
418 "Inferences run",
419 "The number of inferences run",
420 std::string("inferences"));
421 ARMNN_ASSERT(inferencesRunCounter);
422 InitializeCounterValue(inferencesRunCounter->m_Uid);
423 }
424 }
425
InitializeCounterValue(uint16_t counterUid)426 void ProfilingService::InitializeCounterValue(uint16_t counterUid)
427 {
428 // Increase the size of the counter index if necessary
429 if (counterUid >= m_CounterIndex.size())
430 {
431 m_CounterIndex.resize(armnn::numeric_cast<size_t>(counterUid) + 1);
432 }
433
434 // Create a new atomic counter and add it to the list
435 m_CounterValues.emplace_back(0);
436
437 // Register the new counter to the counter index for quick access
438 std::atomic<uint32_t>* counterValuePtr = &(m_CounterValues.back());
439 m_CounterIndex.at(counterUid) = counterValuePtr;
440 }
441
Reset()442 void ProfilingService::Reset()
443 {
444 // Stop the profiling service...
445 Stop();
446
447 // ...then delete all the counter data and configuration...
448 m_CounterIndex.clear();
449 m_CounterValues.clear();
450 m_CounterDirectory.Clear();
451 m_CounterIdMap.Reset();
452 m_BufferManager.Reset();
453
454 // ...finally reset the profiling state machine
455 m_StateMachine.Reset();
456 m_BackendProfilingContexts.clear();
457 m_MaxGlobalCounterId = armnn::profiling::MAX_ARMNN_COUNTER;
458 }
459
Stop()460 void ProfilingService::Stop()
461 {
462 { // only lock when we are updating the inference completed variable
463 std::unique_lock<std::mutex> lck(m_ServiceActiveMutex);
464 m_ServiceActive = false;
465 }
466 // The order in which we reset/stop the components is not trivial!
467 // First stop the producing threads
468 // Command Handler first as it is responsible for launching then Periodic Counter capture thread
469 m_CommandHandler.Stop();
470 m_PeriodicCounterCapture.Stop();
471 // The the consuming thread
472 m_SendThread.Stop(false);
473
474 // ...then close and destroy the profiling connection...
475 if (m_ProfilingConnection != nullptr && m_ProfilingConnection->IsOpen())
476 {
477 m_ProfilingConnection->Close();
478 }
479 m_ProfilingConnection.reset();
480
481 // ...then move to the "NotConnected" state
482 m_StateMachine.TransitionToState(ProfilingState::NotConnected);
483 }
484
CheckCounterUid(uint16_t counterUid) const485 inline void ProfilingService::CheckCounterUid(uint16_t counterUid) const
486 {
487 if (!IsCounterRegistered(counterUid))
488 {
489 throw InvalidArgumentException(fmt::format("Counter UID {} is not registered", counterUid));
490 }
491 }
492
NotifyBackendsForTimelineReporting()493 void ProfilingService::NotifyBackendsForTimelineReporting()
494 {
495 BackendProfilingContext::iterator it = m_BackendProfilingContexts.begin();
496 while (it != m_BackendProfilingContexts.end())
497 {
498 auto& backendProfilingContext = it->second;
499 backendProfilingContext->EnableTimelineReporting(m_TimelineReporting);
500 // Increment the Iterator to point to next entry
501 it++;
502 }
503 }
504
NotifyProfilingServiceActive()505 void ProfilingService::NotifyProfilingServiceActive()
506 {
507 { // only lock when we are updating the inference completed variable
508 std::unique_lock<std::mutex> lck(m_ServiceActiveMutex);
509 m_ServiceActive = true;
510 }
511 m_ServiceActiveConditionVariable.notify_one();
512 }
513
WaitForProfilingServiceActivation(unsigned int timeout)514 void ProfilingService::WaitForProfilingServiceActivation(unsigned int timeout)
515 {
516 std::unique_lock<std::mutex> lck(m_ServiceActiveMutex);
517
518 auto start = std::chrono::high_resolution_clock::now();
519 // Here we we will go back to sleep after a spurious wake up if
520 // m_InferenceCompleted is not yet true.
521 if (!m_ServiceActiveConditionVariable.wait_for(lck,
522 std::chrono::milliseconds(timeout),
523 [&]{return m_ServiceActive == true;}))
524 {
525 if (m_ServiceActive == true)
526 {
527 return;
528 }
529 auto finish = std::chrono::high_resolution_clock::now();
530 std::chrono::duration<double, std::milli> elapsed = finish - start;
531 std::stringstream ss;
532 ss << "Timed out waiting on profiling service activation for " << elapsed.count() << " ms";
533 ARMNN_LOG(warning) << ss.str();
534 }
535 return;
536 }
537
~ProfilingService()538 ProfilingService::~ProfilingService()
539 {
540 Stop();
541 }
542 } // namespace profiling
543
544 } // namespace armnn
545