• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2019 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "PeriodicCounterSelectionCommandHandler.hpp"
7 #include "ProfilingUtils.hpp"
8 
9 #include <armnn/Types.hpp>
10 #include <armnn/utility/NumericCast.hpp>
11 
12 #include <fmt/format.h>
13 
14 #include <vector>
15 
16 namespace armnn
17 {
18 
19 namespace profiling
20 {
21 
ParseData(const arm::pipe::Packet & packet,CaptureData & captureData)22 void PeriodicCounterSelectionCommandHandler::ParseData(const arm::pipe::Packet& packet, CaptureData& captureData)
23 {
24     std::vector<uint16_t> counterIds;
25     uint32_t sizeOfUint32 = armnn::numeric_cast<uint32_t>(sizeof(uint32_t));
26     uint32_t sizeOfUint16 = armnn::numeric_cast<uint32_t>(sizeof(uint16_t));
27     uint32_t offset = 0;
28 
29     if (packet.GetLength() < 4)
30     {
31         // Insufficient packet size
32         return;
33     }
34 
35     // Parse the capture period
36     uint32_t capturePeriod = ReadUint32(packet.GetData(), offset);
37 
38     // Set the capture period
39     captureData.SetCapturePeriod(capturePeriod);
40 
41     // Parse the counter ids
42     unsigned int counters = (packet.GetLength() - 4) / 2;
43     if (counters > 0)
44     {
45         counterIds.reserve(counters);
46         offset += sizeOfUint32;
47         for (unsigned int i = 0; i < counters; ++i)
48         {
49             // Parse the counter id
50             uint16_t counterId = ReadUint16(packet.GetData(), offset);
51             counterIds.emplace_back(counterId);
52             offset += sizeOfUint16;
53         }
54     }
55 
56     // Set the counter ids
57     captureData.SetCounterIds(counterIds);
58 }
59 
operator ()(const arm::pipe::Packet & packet)60 void PeriodicCounterSelectionCommandHandler::operator()(const arm::pipe::Packet& packet)
61 {
62     ProfilingState currentState = m_StateMachine.GetCurrentState();
63     switch (currentState)
64     {
65     case ProfilingState::Uninitialised:
66     case ProfilingState::NotConnected:
67     case ProfilingState::WaitingForAck:
68         throw RuntimeException(fmt::format("Periodic Counter Selection Command Handler invoked while in "
69                                            "an wrong state: {}",
70                                            GetProfilingStateName(currentState)));
71     case ProfilingState::Active:
72     {
73         // Process the packet
74         if (!(packet.GetPacketFamily() == 0u && packet.GetPacketId() == 4u))
75         {
76             throw armnn::InvalidArgumentException(fmt::format("Expected Packet family = 0, id = 4 but "
77                                                               "received family = {}, id = {}",
78                                                               packet.GetPacketFamily(),
79                                                               packet.GetPacketId()));
80         }
81 
82         // Parse the packet to get the capture period and counter UIDs
83         CaptureData captureData;
84         ParseData(packet, captureData);
85 
86         // Get the capture data
87         uint32_t capturePeriod = captureData.GetCapturePeriod();
88         // Validate that the capture period is within the acceptable range.
89         if (capturePeriod > 0  && capturePeriod < LOWEST_CAPTURE_PERIOD)
90         {
91             capturePeriod = LOWEST_CAPTURE_PERIOD;
92         }
93         const std::vector<uint16_t>& counterIds = captureData.GetCounterIds();
94 
95         // Check whether the selected counter UIDs are valid
96         std::vector<uint16_t> validCounterIds;
97         for (uint16_t counterId : counterIds)
98         {
99             // Check whether the counter is registered
100             if (!m_ReadCounterValues.IsCounterRegistered(counterId))
101             {
102                 // Invalid counter UID, ignore it and continue
103                 continue;
104             }
105             // The counter is valid
106             validCounterIds.emplace_back(counterId);
107         }
108 
109         std::sort(validCounterIds.begin(), validCounterIds.end());
110 
111         auto backendIdStart = std::find_if(validCounterIds.begin(), validCounterIds.end(), [&](uint16_t& counterId)
112         {
113             return counterId > m_MaxArmCounterId;
114         });
115 
116         std::set<armnn::BackendId> activeBackends;
117         std::set<uint16_t> backendCounterIds = std::set<uint16_t>(backendIdStart, validCounterIds.end());
118 
119         if (m_BackendCounterMap.size() != 0)
120         {
121             std::set<uint16_t> newCounterIds;
122             std::set<uint16_t> unusedCounterIds;
123 
124             // Get any backend counter ids that is in backendCounterIds but not in m_PrevBackendCounterIds
125             std::set_difference(backendCounterIds.begin(), backendCounterIds.end(),
126                                 m_PrevBackendCounterIds.begin(), m_PrevBackendCounterIds.end(),
127                                 std::inserter(newCounterIds, newCounterIds.begin()));
128 
129             // Get any backend counter ids that is in m_PrevBackendCounterIds but not in backendCounterIds
130             std::set_difference(m_PrevBackendCounterIds.begin(), m_PrevBackendCounterIds.end(),
131                                 backendCounterIds.begin(), backendCounterIds.end(),
132                                 std::inserter(unusedCounterIds, unusedCounterIds.begin()));
133 
134             activeBackends = ProcessBackendCounterIds(capturePeriod, newCounterIds, unusedCounterIds);
135         }
136         else
137         {
138             activeBackends = ProcessBackendCounterIds(capturePeriod, backendCounterIds, {});
139         }
140 
141         // save the new backend counter ids for next time
142         m_PrevBackendCounterIds = backendCounterIds;
143 
144         // Set the capture data with only the valid armnn counter UIDs
145         m_CaptureDataHolder.SetCaptureData(capturePeriod, {validCounterIds.begin(), backendIdStart}, activeBackends);
146 
147         // Echo back the Periodic Counter Selection packet to the Counter Stream Buffer
148         m_SendCounterPacket.SendPeriodicCounterSelectionPacket(capturePeriod, validCounterIds);
149 
150         if (capturePeriod == 0 || validCounterIds.empty())
151         {
152             // No data capture stop the thread
153             m_PeriodicCounterCapture.Stop();
154         }
155         else
156         {
157             // Start the Period Counter Capture thread (if not running already)
158             m_PeriodicCounterCapture.Start();
159         }
160 
161         break;
162     }
163     default:
164         throw RuntimeException(fmt::format("Unknown profiling service state: {}",
165                                            static_cast<int>(currentState)));
166     }
167 }
168 
ProcessBackendCounterIds(const uint32_t capturePeriod,const std::set<uint16_t> newCounterIds,const std::set<uint16_t> unusedCounterIds)169 std::set<armnn::BackendId> PeriodicCounterSelectionCommandHandler::ProcessBackendCounterIds(
170                                                                       const uint32_t capturePeriod,
171                                                                       const std::set<uint16_t> newCounterIds,
172                                                                       const std::set<uint16_t> unusedCounterIds)
173 {
174     std::set<armnn::BackendId> changedBackends;
175     std::set<armnn::BackendId> activeBackends = m_CaptureDataHolder.GetCaptureData().GetActiveBackends();
176 
177     for (uint16_t counterId : newCounterIds)
178     {
179         auto backendId = m_CounterIdMap.GetBackendId(counterId);
180         m_BackendCounterMap[backendId.second].emplace_back(backendId.first);
181         changedBackends.insert(backendId.second);
182     }
183     // Add any new backends to active backends
184     activeBackends.insert(changedBackends.begin(), changedBackends.end());
185 
186     for (uint16_t counterId : unusedCounterIds)
187     {
188         auto backendId = m_CounterIdMap.GetBackendId(counterId);
189         std::vector<uint16_t>& backendCounters = m_BackendCounterMap[backendId.second];
190 
191         backendCounters.erase(std::remove(backendCounters.begin(), backendCounters.end(), backendId.first));
192 
193         if(backendCounters.size() == 0)
194         {
195             // If a backend has no counters associated with it we remove it from active backends and
196             // send a capture period of zero with an empty vector, this will deactivate all the backends counters
197             activeBackends.erase(backendId.second);
198             ActivateBackedCounters(backendId.second, 0, {});
199         }
200         else
201         {
202             changedBackends.insert(backendId.second);
203         }
204     }
205 
206     // If the capture period remains the same we only need to update the backends who's counters have changed
207     if(capturePeriod == m_PrevCapturePeriod)
208     {
209         for (auto backend : changedBackends)
210         {
211             ActivateBackedCounters(backend, capturePeriod, m_BackendCounterMap[backend]);
212         }
213     }
214     // Otherwise update all the backends with the new capture period and any new/unused counters
215     else
216     {
217         for (auto backend : m_BackendCounterMap)
218         {
219             ActivateBackedCounters(backend.first, capturePeriod, backend.second);
220         }
221         if(capturePeriod == 0)
222         {
223             activeBackends = {};
224         }
225         m_PrevCapturePeriod = capturePeriod;
226     }
227 
228     return activeBackends;
229 }
230 
231 } // namespace profiling
232 
233 } // namespace armnn
234