//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#define LOG_TAG "ArmnnDriver"

#include "ArmnnPreparedModel_1_3.hpp"
#include "RequestThread_1_3.hpp"

#include <log/log.h>

using namespace android;

namespace armnn_driver
{

template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
RequestThread_1_3<PreparedModel, HalVersion, CallbackContext>::RequestThread_1_3()
{
    ALOGV("RequestThread_1_3::RequestThread_1_3()");
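    // Start the worker thread that runs Process() and services the priority queues until an EXIT message is posted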
    m_Thread = std::make_unique<std::thread>(&RequestThread_1_3::Process, this);
}

template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
RequestThread_1_3<PreparedModel, HalVersion, CallbackContext>::~RequestThread_1_3()
{
    ALOGV("RequestThread_1_3::~RequestThread_1_3()");

    try
    {
        // Coverity fix: The following code may throw an exception of type std::length_error.

        // This code is meant to terminate the inner thread gracefully by posting an EXIT message
        // to the thread's message queue. However, according to Coverity, this code could throw an exception and fail.
        // Since only one static instance of RequestThread is used in the driver (in ArmnnPreparedModel),
        // this destructor is called only when the application has been closed, which means that
        // the inner thread will be terminated anyway, although abruptly, in the event that the destructor code throws.
        // Wrapping the destructor's code with a try-catch block simply fixes the Coverity bug.

        // Post an EXIT message to the thread
        std::shared_ptr<AsyncExecuteData> nulldata(nullptr);
        auto pMsg = std::make_shared<ThreadMsg>(ThreadMsgType::EXIT, nulldata);
        PostMsg(pMsg);
        // Wait for the thread to terminate; it is deleted automatically
        m_Thread->join();
    }
    catch (const std::exception&) { } // Swallow any exception.
}

template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
void RequestThread_1_3<PreparedModel, HalVersion, CallbackContext>::PostMsg(PreparedModel<HalVersion>* model,
        std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& memPools,
        std::shared_ptr<armnn::InputTensors>& inputTensors,
        std::shared_ptr<armnn::OutputTensors>& outputTensors,
        CallbackContext callbackContext)
{
    ALOGV("RequestThread_1_3::PostMsg(...)");
    auto data = std::make_shared<AsyncExecuteData>(model,
                                                   memPools,
                                                   inputTensors,
                                                   outputTensors,
                                                   callbackContext);
    auto pMsg = std::make_shared<ThreadMsg>(ThreadMsgType::REQUEST, data);
    PostMsg(pMsg, model->GetModelPriority());
}

template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
void RequestThread_1_3<PreparedModel, HalVersion, CallbackContext>::PostMsg(std::shared_ptr<ThreadMsg>& pMsg,
                                                                        V1_3::Priority priority)
{
    ALOGV("RequestThread_1_3::PostMsg(pMsg)");
    // Add a message to the queue and notify the request thread
    std::unique_lock<std::mutex> lock(m_Mutex);
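    // HIGH and LOW requests go to their own queues; MEDIUM and any unrecognised priority value fall through to the medium queue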
    switch (priority) {
        case V1_3::Priority::HIGH:
            m_HighPriorityQueue.push(pMsg);
            break;
        case V1_3::Priority::LOW:
            m_LowPriorityQueue.push(pMsg);
            break;
        case V1_3::Priority::MEDIUM:
        default:
            m_MediumPriorityQueue.push(pMsg);
    }
    m_Cv.notify_one();
}

template <template <typename HalVersion> class PreparedModel, typename HalVersion, typename CallbackContext>
void RequestThread_1_3<PreparedModel, HalVersion, CallbackContext>::Process()
{
    ALOGV("RequestThread_1_3::Process()");
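    // The retire rate bounds how many consecutive high (or medium) priority messages are processed
    // before a lower priority queue gets a turn, so lower priority requests are not starved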
    int retireRate = RETIRE_RATE;
    int highPriorityCount = 0;
    int mediumPriorityCount = 0;
    while (true)
    {
        std::shared_ptr<ThreadMsg> pMsg(nullptr);
        {
            // Wait for a message to be added to the queue
            // This is in a separate scope to minimise the lifetime of the lock
            std::unique_lock<std::mutex> lock(m_Mutex);
            while (m_HighPriorityQueue.empty() && m_MediumPriorityQueue.empty() && m_LowPriorityQueue.empty())
            {
                m_Cv.wait(lock);
            }
            // Get the message to process from the front of each queue based on priority from high to low
            // Get high priority first if it does not exceed the retire rate
            if (!m_HighPriorityQueue.empty() && highPriorityCount < retireRate)
            {
                pMsg = m_HighPriorityQueue.front();
                m_HighPriorityQueue.pop();
                highPriorityCount += 1;
            }
            // If high priority queue is empty or the count exceeds the retire rate, get medium priority message
            else if (!m_MediumPriorityQueue.empty() && mediumPriorityCount < retireRate)
            {
                pMsg = m_MediumPriorityQueue.front();
                m_MediumPriorityQueue.pop();
                mediumPriorityCount += 1;
                // Reset high priority count
                highPriorityCount = 0;
            }
            // If medium priority queue is empty or the count exceeds the retire rate, get low priority message
            else if (!m_LowPriorityQueue.empty())
            {
                pMsg = m_LowPriorityQueue.front();
                m_LowPriorityQueue.pop();
                // Reset high and medium priority count
                highPriorityCount = 0;
                mediumPriorityCount = 0;
            }
            else
            {
                // Reset high and medium priority count
                highPriorityCount = 0;
                mediumPriorityCount = 0;
                continue;
            }
        }

        switch (pMsg->type)
        {
            case ThreadMsgType::REQUEST:
            {
                ALOGV("RequestThread_1_3::Process() - request");
                // invoke the asynchronous execution method
                PreparedModel<HalVersion>* model = pMsg->data->m_Model;
                model->ExecuteGraph(pMsg->data->m_MemPools,
                                    *(pMsg->data->m_InputTensors),
                                    *(pMsg->data->m_OutputTensors),
                                    pMsg->data->m_CallbackContext);
                break;
            }

            case ThreadMsgType::EXIT:
            {
                ALOGV("RequestThread_1_3::Process() - exit");
                // delete all remaining messages (there should not be any)
                std::unique_lock<std::mutex> lock(m_Mutex);
                while (!m_HighPriorityQueue.empty())
                {
                    m_HighPriorityQueue.pop();
                }
                while (!m_MediumPriorityQueue.empty())
                {
                    m_MediumPriorityQueue.pop();
                }
                while (!m_LowPriorityQueue.empty())
                {
                    m_LowPriorityQueue.pop();
                }
                return;
            }

            default:
                // this should be unreachable
                throw armnn::RuntimeException("ArmNN: RequestThread_1_3: invalid message type");
        }
    }
}

///
/// Class template specializations
///

template class RequestThread_1_3<ArmnnPreparedModel_1_3, hal_1_3::HalPolicy, CallbackContext_1_3>;

} // namespace armnn_driver