//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#define LOG_TAG "ArmnnDriver"

#include "RequestThread_1_3.hpp"

#include "ArmnnPreparedModel_1_3.hpp"

#include <armnn/utility/Assert.hpp>

#include <log/log.h>

using namespace android;

namespace armnn_driver
{

template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
RequestThread_1_3<PreparedModel, HalVersion, CallbackContext>::RequestThread_1_3()
{
    ALOGV("RequestThread_1_3::RequestThread_1_3()");
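    // Start the worker thread; it runs Process() until an EXIT message is posted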
    m_Thread = std::make_unique<std::thread>(&RequestThread_1_3::Process, this);
}

template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
RequestThread_1_3<PreparedModel, HalVersion, CallbackContext>::~RequestThread_1_3()
{
    ALOGV("RequestThread_1_3::~RequestThread_1_3()");

    try
    {
        // Coverity fix: The following code may throw an exception of type std::length_error.

        // This code is meant to terminate the inner thread gracefully by posting an EXIT message
        // to the thread's message queue. However, according to Coverity, this code could throw an exception and fail.
        // Since only one static instance of RequestThread is used in the driver (in ArmnnPreparedModel),
        // this destructor is called only when the application has been closed, which means that
        // the inner thread will be terminated anyway, although abruptly, in the event that the destructor code throws.
        // Wrapping the destructor's code with a try-catch block simply fixes the Coverity bug.

        // Post an EXIT message to the thread
        std::shared_ptr<AsyncExecuteData> nulldata(nullptr);
        auto pMsg = std::make_shared<ThreadMsg>(ThreadMsgType::EXIT, nulldata);
        PostMsg(pMsg);
        // Wait for the thread to terminate; it is deleted automatically
        m_Thread->join();
    }
    catch (const std::exception&) { } // Swallow any exception.
}

template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
void RequestThread_1_3<PreparedModel, HalVersion, CallbackContext>::PostMsg(PreparedModel<HalVersion>* model,
        std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools,
        std::shared_ptr<armnn::InputTensors>& inputTensors,
        std::shared_ptr<armnn::OutputTensors>& outputTensors,
        CallbackContext callbackContext)
{
    ALOGV("RequestThread_1_3::PostMsg(...)");
    auto data = std::make_shared<AsyncExecuteData>(model,
                                                   memPools,
                                                   inputTensors,
                                                   outputTensors,
                                                   callbackContext);
    auto pMsg = std::make_shared<ThreadMsg>(ThreadMsgType::REQUEST, data);
    PostMsg(pMsg, model->GetModelPriority());
}

template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
void RequestThread_1_3<PreparedModel, HalVersion, CallbackContext>::PostMsg(std::shared_ptr<ThreadMsg>& pMsg,
                                                                        V1_3::Priority priority)
{
    ALOGV("RequestThread_1_3::PostMsg(pMsg)");
    // Add a message to the queue and notify the request thread
    std::unique_lock<std::mutex> lock(m_Mutex);
    switch (priority) {
        case V1_3::Priority::HIGH:
            m_HighPriorityQueue.push(pMsg);
            break;
        case V1_3::Priority::LOW:
            m_LowPriorityQueue.push(pMsg);
            break;
        case V1_3::Priority::MEDIUM:
        default:
            m_MediumPriorityQueue.push(pMsg);
    }
    m_Cv.notify_one();
}

template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
void RequestThread_1_3<PreparedModel, HalVersion, CallbackContext>::Process()
{
    ALOGV("RequestThread_1_3::Process()");
    int retireRate = RETIRE_RATE;
    int highPriorityCount = 0;
    int mediumPriorityCount = 0;
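    // Up to retireRate consecutive messages are drained from a higher-priority queue before a
    // lower-priority queue is serviced, so lower-priority requests are not starved indefinitely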
    while (true)
    {
        std::shared_ptr<ThreadMsg> pMsg(nullptr);
        {
            // Wait for a message to be added to the queue
            // This is in a separate scope to minimise the lifetime of the lock
            std::unique_lock<std::mutex> lock(m_Mutex);
            while (m_HighPriorityQueue.empty() && m_MediumPriorityQueue.empty() && m_LowPriorityQueue.empty())
            {
                m_Cv.wait(lock);
            }
            // Get the message to process from the front of each queue based on priority from high to low
            // Get high priority first if it does not exceed the retire rate
            if (!m_HighPriorityQueue.empty() && highPriorityCount < retireRate)
            {
                pMsg = m_HighPriorityQueue.front();
                m_HighPriorityQueue.pop();
                highPriorityCount += 1;
            }
            // If high priority queue is empty or the count exceeds the retire rate, get medium priority message
            else if (!m_MediumPriorityQueue.empty() && mediumPriorityCount < retireRate)
            {
                pMsg = m_MediumPriorityQueue.front();
                m_MediumPriorityQueue.pop();
                mediumPriorityCount += 1;
                // Reset high priority count
                highPriorityCount = 0;
            }
            // If medium priority queue is empty or the count exceeds the retire rate, get low priority message
            else if (!m_LowPriorityQueue.empty())
            {
                pMsg = m_LowPriorityQueue.front();
                m_LowPriorityQueue.pop();
                // Reset high and medium priority count
                highPriorityCount = 0;
                mediumPriorityCount = 0;
            }
            else
            {
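                // Reached only when the queues that still hold work have hit the retire rate
                // and the low priority queue is empty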
                // Reset high and medium priority count
                highPriorityCount = 0;
                mediumPriorityCount = 0;
                continue;
            }
        }

        switch (pMsg->type)
        {
            case ThreadMsgType::REQUEST:
            {
                ALOGV("RequestThread_1_3::Process() - request");
                // invoke the asynchronous execution method
                PreparedModel<HalVersion>* model = pMsg->data->m_Model;
                model->ExecuteGraph(pMsg->data->m_MemPools,
                                    *(pMsg->data->m_InputTensors),
                                    *(pMsg->data->m_OutputTensors),
                                    pMsg->data->m_CallbackContext);
                break;
            }

            case ThreadMsgType::EXIT:
            {
                ALOGV("RequestThread_1_3::Process() - exit");
                // delete all remaining messages (there should not be any)
                std::unique_lock<std::mutex> lock(m_Mutex);
                while (!m_HighPriorityQueue.empty())
                {
                    m_HighPriorityQueue.pop();
                }
                while (!m_MediumPriorityQueue.empty())
                {
                    m_MediumPriorityQueue.pop();
                }
                while (!m_LowPriorityQueue.empty())
                {
                    m_LowPriorityQueue.pop();
                }
                return;
            }

            default:
                // this should be unreachable
                ALOGE("RequestThread_1_3::Process() - invalid message type");
                ARMNN_ASSERT_MSG(false, "ArmNN: RequestThread_1_3: invalid message type");
        }
    }
}

///
/// Class template specializations
///

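// Explicit instantiation for the 1.3 HAL: the template's member definitions live in this .cpp,
// so instantiating it here lets other translation units link against it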
template class RequestThread_1_3<ArmnnPreparedModel_1_3, hal_1_3::HalPolicy, CallbackContext_1_3>;

} // namespace armnn_driver