1 //
2 // Copyright © 2019 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "PeriodicCounterSelectionCommandHandler.hpp"
7 #include "ProfilingUtils.hpp"
8
9 #include <armnn/Types.hpp>
10 #include <armnn/utility/NumericCast.hpp>
11
12 #include <fmt/format.h>
13
14 #include <vector>
15
16 namespace armnn
17 {
18
19 namespace profiling
20 {
21
ParseData(const arm::pipe::Packet & packet,CaptureData & captureData)22 void PeriodicCounterSelectionCommandHandler::ParseData(const arm::pipe::Packet& packet, CaptureData& captureData)
23 {
24 std::vector<uint16_t> counterIds;
25 uint32_t sizeOfUint32 = armnn::numeric_cast<uint32_t>(sizeof(uint32_t));
26 uint32_t sizeOfUint16 = armnn::numeric_cast<uint32_t>(sizeof(uint16_t));
27 uint32_t offset = 0;
28
29 if (packet.GetLength() < 4)
30 {
31 // Insufficient packet size
32 return;
33 }
34
35 // Parse the capture period
36 uint32_t capturePeriod = ReadUint32(packet.GetData(), offset);
37
38 // Set the capture period
39 captureData.SetCapturePeriod(capturePeriod);
40
41 // Parse the counter ids
42 unsigned int counters = (packet.GetLength() - 4) / 2;
43 if (counters > 0)
44 {
45 counterIds.reserve(counters);
46 offset += sizeOfUint32;
47 for (unsigned int i = 0; i < counters; ++i)
48 {
49 // Parse the counter id
50 uint16_t counterId = ReadUint16(packet.GetData(), offset);
51 counterIds.emplace_back(counterId);
52 offset += sizeOfUint16;
53 }
54 }
55
56 // Set the counter ids
57 captureData.SetCounterIds(counterIds);
58 }
59
operator ()(const arm::pipe::Packet & packet)60 void PeriodicCounterSelectionCommandHandler::operator()(const arm::pipe::Packet& packet)
61 {
62 ProfilingState currentState = m_StateMachine.GetCurrentState();
63 switch (currentState)
64 {
65 case ProfilingState::Uninitialised:
66 case ProfilingState::NotConnected:
67 case ProfilingState::WaitingForAck:
68 throw RuntimeException(fmt::format("Periodic Counter Selection Command Handler invoked while in "
69 "an wrong state: {}",
70 GetProfilingStateName(currentState)));
71 case ProfilingState::Active:
72 {
73 // Process the packet
74 if (!(packet.GetPacketFamily() == 0u && packet.GetPacketId() == 4u))
75 {
76 throw armnn::InvalidArgumentException(fmt::format("Expected Packet family = 0, id = 4 but "
77 "received family = {}, id = {}",
78 packet.GetPacketFamily(),
79 packet.GetPacketId()));
80 }
81
82 // Parse the packet to get the capture period and counter UIDs
83 CaptureData captureData;
84 ParseData(packet, captureData);
85
86 // Get the capture data
87 uint32_t capturePeriod = captureData.GetCapturePeriod();
88 // Validate that the capture period is within the acceptable range.
89 if (capturePeriod > 0 && capturePeriod < LOWEST_CAPTURE_PERIOD)
90 {
91 capturePeriod = LOWEST_CAPTURE_PERIOD;
92 }
93 const std::vector<uint16_t>& counterIds = captureData.GetCounterIds();
94
95 // Check whether the selected counter UIDs are valid
96 std::vector<uint16_t> validCounterIds;
97 for (uint16_t counterId : counterIds)
98 {
99 // Check whether the counter is registered
100 if (!m_ReadCounterValues.IsCounterRegistered(counterId))
101 {
102 // Invalid counter UID, ignore it and continue
103 continue;
104 }
105 // The counter is valid
106 validCounterIds.emplace_back(counterId);
107 }
108
109 std::sort(validCounterIds.begin(), validCounterIds.end());
110
111 auto backendIdStart = std::find_if(validCounterIds.begin(), validCounterIds.end(), [&](uint16_t& counterId)
112 {
113 return counterId > m_MaxArmCounterId;
114 });
115
116 std::set<armnn::BackendId> activeBackends;
117 std::set<uint16_t> backendCounterIds = std::set<uint16_t>(backendIdStart, validCounterIds.end());
118
119 if (m_BackendCounterMap.size() != 0)
120 {
121 std::set<uint16_t> newCounterIds;
122 std::set<uint16_t> unusedCounterIds;
123
124 // Get any backend counter ids that is in backendCounterIds but not in m_PrevBackendCounterIds
125 std::set_difference(backendCounterIds.begin(), backendCounterIds.end(),
126 m_PrevBackendCounterIds.begin(), m_PrevBackendCounterIds.end(),
127 std::inserter(newCounterIds, newCounterIds.begin()));
128
129 // Get any backend counter ids that is in m_PrevBackendCounterIds but not in backendCounterIds
130 std::set_difference(m_PrevBackendCounterIds.begin(), m_PrevBackendCounterIds.end(),
131 backendCounterIds.begin(), backendCounterIds.end(),
132 std::inserter(unusedCounterIds, unusedCounterIds.begin()));
133
134 activeBackends = ProcessBackendCounterIds(capturePeriod, newCounterIds, unusedCounterIds);
135 }
136 else
137 {
138 activeBackends = ProcessBackendCounterIds(capturePeriod, backendCounterIds, {});
139 }
140
141 // save the new backend counter ids for next time
142 m_PrevBackendCounterIds = backendCounterIds;
143
144 // Set the capture data with only the valid armnn counter UIDs
145 m_CaptureDataHolder.SetCaptureData(capturePeriod, {validCounterIds.begin(), backendIdStart}, activeBackends);
146
147 // Echo back the Periodic Counter Selection packet to the Counter Stream Buffer
148 m_SendCounterPacket.SendPeriodicCounterSelectionPacket(capturePeriod, validCounterIds);
149
150 if (capturePeriod == 0 || validCounterIds.empty())
151 {
152 // No data capture stop the thread
153 m_PeriodicCounterCapture.Stop();
154 }
155 else
156 {
157 // Start the Period Counter Capture thread (if not running already)
158 m_PeriodicCounterCapture.Start();
159 }
160
161 break;
162 }
163 default:
164 throw RuntimeException(fmt::format("Unknown profiling service state: {}",
165 static_cast<int>(currentState)));
166 }
167 }
168
ProcessBackendCounterIds(const uint32_t capturePeriod,const std::set<uint16_t> newCounterIds,const std::set<uint16_t> unusedCounterIds)169 std::set<armnn::BackendId> PeriodicCounterSelectionCommandHandler::ProcessBackendCounterIds(
170 const uint32_t capturePeriod,
171 const std::set<uint16_t> newCounterIds,
172 const std::set<uint16_t> unusedCounterIds)
173 {
174 std::set<armnn::BackendId> changedBackends;
175 std::set<armnn::BackendId> activeBackends = m_CaptureDataHolder.GetCaptureData().GetActiveBackends();
176
177 for (uint16_t counterId : newCounterIds)
178 {
179 auto backendId = m_CounterIdMap.GetBackendId(counterId);
180 m_BackendCounterMap[backendId.second].emplace_back(backendId.first);
181 changedBackends.insert(backendId.second);
182 }
183 // Add any new backends to active backends
184 activeBackends.insert(changedBackends.begin(), changedBackends.end());
185
186 for (uint16_t counterId : unusedCounterIds)
187 {
188 auto backendId = m_CounterIdMap.GetBackendId(counterId);
189 std::vector<uint16_t>& backendCounters = m_BackendCounterMap[backendId.second];
190
191 backendCounters.erase(std::remove(backendCounters.begin(), backendCounters.end(), backendId.first));
192
193 if(backendCounters.size() == 0)
194 {
195 // If a backend has no counters associated with it we remove it from active backends and
196 // send a capture period of zero with an empty vector, this will deactivate all the backends counters
197 activeBackends.erase(backendId.second);
198 ActivateBackedCounters(backendId.second, 0, {});
199 }
200 else
201 {
202 changedBackends.insert(backendId.second);
203 }
204 }
205
206 // If the capture period remains the same we only need to update the backends who's counters have changed
207 if(capturePeriod == m_PrevCapturePeriod)
208 {
209 for (auto backend : changedBackends)
210 {
211 ActivateBackedCounters(backend, capturePeriod, m_BackendCounterMap[backend]);
212 }
213 }
214 // Otherwise update all the backends with the new capture period and any new/unused counters
215 else
216 {
217 for (auto backend : m_BackendCounterMap)
218 {
219 ActivateBackedCounters(backend.first, capturePeriod, backend.second);
220 }
221 if(capturePeriod == 0)
222 {
223 activeBackends = {};
224 }
225 m_PrevCapturePeriod = capturePeriod;
226 }
227
228 return activeBackends;
229 }
230
231 } // namespace profiling
232
233 } // namespace armnn
234