• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #define LOG_TAG "ArmnnDriver"
7 
8 #include "ArmnnDriverImpl.hpp"
9 #include "ArmnnPreparedModel.hpp"
10 
11 #if defined(ARMNN_ANDROID_NN_V1_2) || defined(ARMNN_ANDROID_NN_V1_3) // Using ::android::hardware::neuralnetworks::V1_2
12 #include "ArmnnPreparedModel_1_2.hpp"
13 #endif
14 
#ifdef ARMNN_ANDROID_NN_V1_3 // Using ::android::hardware::neuralnetworks::V1_3
16 #include "ArmnnPreparedModel_1_3.hpp"
17 #endif
18 
19 #include "Utils.hpp"
20 
21 #include "ModelToINetworkConverter.hpp"
22 #include "SystemPropertiesUtils.hpp"
23 
24 #include <ValidateHal.h>
25 #include <log/log.h>
26 
27 using namespace std;
28 using namespace android;
29 using namespace android::nn;
30 using namespace android::hardware;
31 
32 namespace
33 {
34 
NotifyCallbackAndCheck(const sp<V1_0::IPreparedModelCallback> & callback,V1_0::ErrorStatus errorStatus,const sp<V1_0::IPreparedModel> & preparedModelPtr)35 void NotifyCallbackAndCheck(const sp<V1_0::IPreparedModelCallback>& callback,
36                             V1_0::ErrorStatus errorStatus,
37                             const sp<V1_0::IPreparedModel>& preparedModelPtr)
38 {
39     Return<void> returned = callback->notify(errorStatus, preparedModelPtr);
40     // This check is required, if the callback fails and it isn't checked it will bring down the service
41     if (!returned.isOk())
42     {
43         ALOGE("ArmnnDriverImpl::prepareModel: hidl callback failed to return properly: %s ",
44               returned.description().c_str());
45     }
46 }
47 
FailPrepareModel(V1_0::ErrorStatus error,const string & message,const sp<V1_0::IPreparedModelCallback> & callback)48 Return<V1_0::ErrorStatus> FailPrepareModel(V1_0::ErrorStatus error,
49                                            const string& message,
50                                            const sp<V1_0::IPreparedModelCallback>& callback)
51 {
52     ALOGW("ArmnnDriverImpl::prepareModel: %s", message.c_str());
53     NotifyCallbackAndCheck(callback, error, nullptr);
54     return error;
55 }
56 
57 } // namespace
58 
59 namespace armnn_driver
60 {
61 
62 template<typename HalPolicy>
prepareModel(const armnn::IRuntimePtr & runtime,const armnn::IGpuAccTunedParametersPtr & clTunedParameters,const DriverOptions & options,const HalModel & model,const sp<V1_0::IPreparedModelCallback> & cb,bool float32ToFloat16)63 Return<V1_0::ErrorStatus> ArmnnDriverImpl<HalPolicy>::prepareModel(
64         const armnn::IRuntimePtr& runtime,
65         const armnn::IGpuAccTunedParametersPtr& clTunedParameters,
66         const DriverOptions& options,
67         const HalModel& model,
68         const sp<V1_0::IPreparedModelCallback>& cb,
69         bool float32ToFloat16)
70 {
71     ALOGV("ArmnnDriverImpl::prepareModel()");
72 
73     if (cb.get() == nullptr)
74     {
75         ALOGW("ArmnnDriverImpl::prepareModel: Invalid callback passed to prepareModel");
76         return V1_0::ErrorStatus::INVALID_ARGUMENT;
77     }
78 
79     if (!runtime)
80     {
81         return FailPrepareModel(V1_0::ErrorStatus::DEVICE_UNAVAILABLE, "Device unavailable", cb);
82     }
83 
84     if (!android::nn::validateModel(model))
85     {
86         return FailPrepareModel(V1_0::ErrorStatus::INVALID_ARGUMENT, "Invalid model passed as input", cb);
87     }
88 
89     // Deliberately ignore any unsupported operations requested by the options -
90     // at this point we're being asked to prepare a model that we've already declared support for
91     // and the operation indices may be different to those in getSupportedOperations anyway.
92     set<unsigned int> unsupportedOperations;
93     ModelToINetworkConverter<HalPolicy> modelConverter(options.GetBackends(),
94                                                        model,
95                                                        unsupportedOperations);
96 
97     if (modelConverter.GetConversionResult() != ConversionResult::Success)
98     {
99         FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "ModelToINetworkConverter failed", cb);
100         return V1_0::ErrorStatus::NONE;
101     }
102 
103     // Optimize the network
104     armnn::IOptimizedNetworkPtr optNet(nullptr, nullptr);
105     armnn::OptimizerOptions OptOptions;
106     OptOptions.m_ReduceFp32ToFp16 = float32ToFloat16;
107 
108     armnn::BackendOptions gpuAcc("GpuAcc",
109     {
110         { "FastMathEnabled", options.IsFastMathEnabled() }
111     });
112     armnn::BackendOptions cpuAcc("CpuAcc",
113     {
114         { "FastMathEnabled", options.IsFastMathEnabled() }
115     });
116     OptOptions.m_ModelOptions.push_back(gpuAcc);
117     OptOptions.m_ModelOptions.push_back(cpuAcc);
118 
119     std::vector<std::string> errMessages;
120     try
121     {
122         optNet = armnn::Optimize(*modelConverter.GetINetwork(),
123                                  options.GetBackends(),
124                                  runtime->GetDeviceSpec(),
125                                  OptOptions,
126                                  errMessages);
127     }
128     catch (std::exception &e)
129     {
130         stringstream message;
131         message << "Exception (" << e.what() << ") caught from optimize.";
132         FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, message.str(), cb);
133         return V1_0::ErrorStatus::NONE;
134     }
135 
136     // Check that the optimized network is valid.
137     if (!optNet)
138     {
139         stringstream message;
140         message << "Invalid optimized network";
141         for (const string& msg : errMessages)
142         {
143             message << "\n" << msg;
144         }
145         FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, message.str(), cb);
146         return V1_0::ErrorStatus::NONE;
147     }
148 
149     // Export the optimized network graph to a dot file if an output dump directory
150     // has been specified in the drivers' arguments.
151     std::string dotGraphFileName = ExportNetworkGraphToDotFile(*optNet, options.GetRequestInputsAndOutputsDumpDir());
152 
153     // Load it into the runtime.
154     armnn::NetworkId netId = 0;
155     try
156     {
157         if (runtime->LoadNetwork(netId, move(optNet)) != armnn::Status::Success)
158         {
159             return FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "Network could not be loaded", cb);
160         }
161     }
162     catch (std::exception& e)
163     {
164         stringstream message;
165         message << "Exception (" << e.what()<< ") caught from LoadNetwork.";
166         FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, message.str(), cb);
167         return V1_0::ErrorStatus::NONE;
168     }
169 
170     // Now that we have a networkId for the graph rename the dump file to use it
171     // so that we can associate the graph file and the input/output tensor dump files
172     RenameGraphDotFile(dotGraphFileName,
173                        options.GetRequestInputsAndOutputsDumpDir(),
174                        netId);
175 
176     sp<ArmnnPreparedModel<HalPolicy>> preparedModel(
177             new ArmnnPreparedModel<HalPolicy>(
178                     netId,
179                     runtime.get(),
180                     model,
181                     options.GetRequestInputsAndOutputsDumpDir(),
182                     options.IsGpuProfilingEnabled()));
183 
184     // Run a single 'dummy' inference of the model. This means that CL kernels will get compiled (and tuned if
185     // this is enabled) before the first 'real' inference which removes the overhead of the first inference.
186     if (!preparedModel->ExecuteWithDummyInputs())
187     {
188         return FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "Network could not be executed", cb);
189     }
190 
191     if (clTunedParameters &&
192         options.GetClTunedParametersMode() == armnn::IGpuAccTunedParameters::Mode::UpdateTunedParameters)
193     {
194         // Now that we've done one inference the CL kernel parameters will have been tuned, so save the updated file.
195         try
196         {
197             clTunedParameters->Save(options.GetClTunedParametersFile().c_str());
198         }
199         catch (std::exception& error)
200         {
201             ALOGE("ArmnnDriverImpl::prepareModel: Failed to save CL tuned parameters file '%s': %s",
202                   options.GetClTunedParametersFile().c_str(), error.what());
203         }
204     }
205 
206     NotifyCallbackAndCheck(cb, V1_0::ErrorStatus::NONE, preparedModel);
207 
208     return V1_0::ErrorStatus::NONE;
209 }
210 
211 template<typename HalPolicy>
getSupportedOperations(const armnn::IRuntimePtr & runtime,const DriverOptions & options,const HalModel & model,HalGetSupportedOperations_cb cb)212 Return<void> ArmnnDriverImpl<HalPolicy>::getSupportedOperations(const armnn::IRuntimePtr& runtime,
213                                                                 const DriverOptions& options,
214                                                                 const HalModel& model,
215                                                                 HalGetSupportedOperations_cb cb)
216 {
217     std::stringstream ss;
218     ss << "ArmnnDriverImpl::getSupportedOperations()";
219     std::string fileName;
220     std::string timestamp;
221     if (!options.GetRequestInputsAndOutputsDumpDir().empty())
222     {
223         ss << " : "
224            << options.GetRequestInputsAndOutputsDumpDir()
225            << "/"
226            << GetFileTimestamp()
227            << "_getSupportedOperations.txt";
228     }
229     ALOGV(ss.str().c_str());
230 
231     if (!options.GetRequestInputsAndOutputsDumpDir().empty())
232     {
233         //dump the marker file
234         std::ofstream fileStream;
235         fileStream.open(fileName, std::ofstream::out | std::ofstream::trunc);
236         if (fileStream.good())
237         {
238             fileStream << timestamp << std::endl;
239         }
240         fileStream.close();
241     }
242 
243     vector<bool> result;
244 
245     if (!runtime)
246     {
247         cb(HalErrorStatus::DEVICE_UNAVAILABLE, result);
248         return Void();
249     }
250 
251     // Run general model validation, if this doesn't pass we shouldn't analyse the model anyway.
252     if (!android::nn::validateModel(model))
253     {
254         cb(HalErrorStatus::INVALID_ARGUMENT, result);
255         return Void();
256     }
257 
258     // Attempt to convert the model to an ArmNN input network (INetwork).
259     ModelToINetworkConverter<HalPolicy> modelConverter(options.GetBackends(),
260                                                        model,
261                                                        options.GetForcedUnsupportedOperations());
262 
263     if (modelConverter.GetConversionResult() != ConversionResult::Success
264             && modelConverter.GetConversionResult() != ConversionResult::UnsupportedFeature)
265     {
266         cb(HalErrorStatus::GENERAL_FAILURE, result);
267         return Void();
268     }
269 
270     // Check each operation if it was converted successfully and copy the flags
271     // into the result (vector<bool>) that we need to return to Android.
272     result.reserve(getMainModel(model).operations.size());
273     for (uint32_t operationIdx = 0;
274          operationIdx < getMainModel(model).operations.size();
275          ++operationIdx)
276     {
277         bool operationSupported = modelConverter.IsOperationSupported(operationIdx);
278         result.push_back(operationSupported);
279     }
280 
281     cb(HalErrorStatus::NONE, result);
282     return Void();
283 }
284 
285 template<typename HalPolicy>
getStatus()286 Return<V1_0::DeviceStatus> ArmnnDriverImpl<HalPolicy>::getStatus()
287 {
288     ALOGV("ArmnnDriver::getStatus()");
289 
290     return V1_0::DeviceStatus::AVAILABLE;
291 }
292 
293 ///
294 /// Class template specializations
295 ///
296 
297 template class ArmnnDriverImpl<hal_1_0::HalPolicy>;
298 
299 #ifdef ARMNN_ANDROID_NN_V1_1
300 template class ArmnnDriverImpl<hal_1_1::HalPolicy>;
301 #endif
302 
303 #ifdef ARMNN_ANDROID_NN_V1_2
304 template class ArmnnDriverImpl<hal_1_1::HalPolicy>;
305 template class ArmnnDriverImpl<hal_1_2::HalPolicy>;
306 #endif
307 
308 #ifdef ARMNN_ANDROID_NN_V1_3
309 template class ArmnnDriverImpl<hal_1_1::HalPolicy>;
310 template class ArmnnDriverImpl<hal_1_2::HalPolicy>;
311 template class ArmnnDriverImpl<hal_1_3::HalPolicy>;
312 #endif
313 
314 } // namespace armnn_driver
315