1 //
2 // Copyright © 2019 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
#include "PeriodicCounterSelectionCommandHandler.hpp"
#include "ProfilingUtils.hpp"

#include <client/include/ProfilingOptions.hpp>

#include <common/include/NumericCast.hpp>

#include <fmt/format.h>

#include <algorithm>
#include <iterator>
#include <set>
#include <string>
#include <vector>
16
17 namespace arm
18 {
19
20 namespace pipe
21 {
22
ParseData(const arm::pipe::Packet & packet,CaptureData & captureData)23 void PeriodicCounterSelectionCommandHandler::ParseData(const arm::pipe::Packet& packet, CaptureData& captureData)
24 {
25 std::vector<uint16_t> counterIds;
26 uint32_t sizeOfUint32 = arm::pipe::numeric_cast<uint32_t>(sizeof(uint32_t));
27 uint32_t sizeOfUint16 = arm::pipe::numeric_cast<uint32_t>(sizeof(uint16_t));
28 uint32_t offset = 0;
29
30 if (packet.GetLength() < 4)
31 {
32 // Insufficient packet size
33 return;
34 }
35
36 // Parse the capture period
37 uint32_t capturePeriod = ReadUint32(packet.GetData(), offset);
38
39 // Set the capture period
40 captureData.SetCapturePeriod(capturePeriod);
41
42 // Parse the counter ids
43 unsigned int counters = (packet.GetLength() - 4) / 2;
44 if (counters > 0)
45 {
46 counterIds.reserve(counters);
47 offset += sizeOfUint32;
48 for (unsigned int i = 0; i < counters; ++i)
49 {
50 // Parse the counter id
51 uint16_t counterId = ReadUint16(packet.GetData(), offset);
52 counterIds.emplace_back(counterId);
53 offset += sizeOfUint16;
54 }
55 }
56
57 // Set the counter ids
58 captureData.SetCounterIds(counterIds);
59 }
60
operator ()(const arm::pipe::Packet & packet)61 void PeriodicCounterSelectionCommandHandler::operator()(const arm::pipe::Packet& packet)
62 {
63 ProfilingState currentState = m_StateMachine.GetCurrentState();
64 switch (currentState)
65 {
66 case ProfilingState::Uninitialised:
67 case ProfilingState::NotConnected:
68 case ProfilingState::WaitingForAck:
69 throw arm::pipe::ProfilingException(fmt::format("Periodic Counter Selection Command Handler invoked while in "
70 "an wrong state: {}",
71 GetProfilingStateName(currentState)));
72 case ProfilingState::Active:
73 {
74 // Process the packet
75 if (!(packet.GetPacketFamily() == 0u && packet.GetPacketId() == 4u))
76 {
77 throw arm::pipe::InvalidArgumentException(fmt::format("Expected Packet family = 0, id = 4 but "
78 "received family = {}, id = {}",
79 packet.GetPacketFamily(),
80 packet.GetPacketId()));
81 }
82
83 // Parse the packet to get the capture period and counter UIDs
84 CaptureData captureData;
85 ParseData(packet, captureData);
86
87 // Get the capture data
88 uint32_t capturePeriod = captureData.GetCapturePeriod();
89 // Validate that the capture period is within the acceptable range.
90 if (capturePeriod > 0 && capturePeriod < arm::pipe::LOWEST_CAPTURE_PERIOD)
91 {
92 capturePeriod = arm::pipe::LOWEST_CAPTURE_PERIOD;
93 }
94 const std::vector<uint16_t>& counterIds = captureData.GetCounterIds();
95
96 // Check whether the selected counter UIDs are valid
97 std::vector<uint16_t> validCounterIds;
98 for (uint16_t counterId : counterIds)
99 {
100 // Check whether the counter is registered
101 if (!m_ReadCounterValues.IsCounterRegistered(counterId))
102 {
103 // Invalid counter UID, ignore it and continue
104 continue;
105 }
106 // The counter is valid
107 validCounterIds.emplace_back(counterId);
108 }
109
110 std::sort(validCounterIds.begin(), validCounterIds.end());
111
112 auto backendIdStart = std::find_if(validCounterIds.begin(), validCounterIds.end(), [&](uint16_t& counterId)
113 {
114 return counterId > m_MaxArmCounterId;
115 });
116
117 std::set<std::string> activeBackends;
118 std::set<uint16_t> backendCounterIds = std::set<uint16_t>(backendIdStart, validCounterIds.end());
119
120 if (m_BackendCounterMap.size() != 0)
121 {
122 std::set<uint16_t> newCounterIds;
123 std::set<uint16_t> unusedCounterIds;
124
125 // Get any backend counter ids that is in backendCounterIds but not in m_PrevBackendCounterIds
126 std::set_difference(backendCounterIds.begin(), backendCounterIds.end(),
127 m_PrevBackendCounterIds.begin(), m_PrevBackendCounterIds.end(),
128 std::inserter(newCounterIds, newCounterIds.begin()));
129
130 // Get any backend counter ids that is in m_PrevBackendCounterIds but not in backendCounterIds
131 std::set_difference(m_PrevBackendCounterIds.begin(), m_PrevBackendCounterIds.end(),
132 backendCounterIds.begin(), backendCounterIds.end(),
133 std::inserter(unusedCounterIds, unusedCounterIds.begin()));
134
135 activeBackends = ProcessBackendCounterIds(capturePeriod, newCounterIds, unusedCounterIds);
136 }
137 else
138 {
139 activeBackends = ProcessBackendCounterIds(capturePeriod, backendCounterIds, {});
140 }
141
142 // save the new backend counter ids for next time
143 m_PrevBackendCounterIds = backendCounterIds;
144
145 // Set the capture data with only the valid armnn counter UIDs
146 m_CaptureDataHolder.SetCaptureData(capturePeriod, {validCounterIds.begin(), backendIdStart}, activeBackends);
147
148 // Echo back the Periodic Counter Selection packet to the Counter Stream Buffer
149 m_SendCounterPacket.SendPeriodicCounterSelectionPacket(capturePeriod, validCounterIds);
150
151 if (capturePeriod == 0 || validCounterIds.empty())
152 {
153 // No data capture stop the thread
154 m_PeriodicCounterCapture.Stop();
155 }
156 else
157 {
158 // Start the Period Counter Capture thread (if not running already)
159 m_PeriodicCounterCapture.Start();
160 }
161
162 break;
163 }
164 default:
165 throw arm::pipe::ProfilingException(fmt::format("Unknown profiling service state: {}",
166 static_cast<int>(currentState)));
167 }
168 }
169
ProcessBackendCounterIds(const uint32_t capturePeriod,const std::set<uint16_t> newCounterIds,const std::set<uint16_t> unusedCounterIds)170 std::set<std::string> PeriodicCounterSelectionCommandHandler::ProcessBackendCounterIds(
171 const uint32_t capturePeriod,
172 const std::set<uint16_t> newCounterIds,
173 const std::set<uint16_t> unusedCounterIds)
174 {
175 std::set<std::string> changedBackends;
176 std::set<std::string> activeBackends = m_CaptureDataHolder.GetCaptureData().GetActiveBackends();
177
178 for (uint16_t counterId : newCounterIds)
179 {
180 auto backendId = m_CounterIdMap.GetBackendId(counterId);
181 m_BackendCounterMap[backendId.second].emplace_back(backendId.first);
182 changedBackends.insert(backendId.second);
183 }
184 // Add any new backends to active backends
185 activeBackends.insert(changedBackends.begin(), changedBackends.end());
186
187 for (uint16_t counterId : unusedCounterIds)
188 {
189 auto backendId = m_CounterIdMap.GetBackendId(counterId);
190 std::vector<uint16_t>& backendCounters = m_BackendCounterMap[backendId.second];
191
192 backendCounters.erase(std::remove(backendCounters.begin(), backendCounters.end(), backendId.first));
193
194 if(backendCounters.size() == 0)
195 {
196 // If a backend has no counters associated with it we remove it from active backends and
197 // send a capture period of zero with an empty vector, this will deactivate all the backends counters
198 activeBackends.erase(backendId.second);
199 ActivateBackendCounters(backendId.second, 0, {});
200 }
201 else
202 {
203 changedBackends.insert(backendId.second);
204 }
205 }
206
207 // If the capture period remains the same we only need to update the backends who's counters have changed
208 if(capturePeriod == m_PrevCapturePeriod)
209 {
210 for (auto backend : changedBackends)
211 {
212 ActivateBackendCounters(backend, capturePeriod, m_BackendCounterMap[backend]);
213 }
214 }
215 // Otherwise update all the backends with the new capture period and any new/unused counters
216 else
217 {
218 for (auto backend : m_BackendCounterMap)
219 {
220 ActivateBackendCounters(backend.first, capturePeriod, backend.second);
221 }
222 if(capturePeriod == 0)
223 {
224 activeBackends = {};
225 }
226 m_PrevCapturePeriod = capturePeriod;
227 }
228
229 return activeBackends;
230 }
231
232 } // namespace pipe
233
234 } // namespace arm
235