1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include "ArmnnDriver.hpp" 9 #include "ArmnnDriverImpl.hpp" 10 #include "RequestThread.hpp" 11 12 #include <NeuralNetworks.h> 13 #include <armnn/ArmNN.hpp> 14 15 #include <string> 16 #include <vector> 17 18 namespace armnn_driver 19 { 20 using armnnExecuteCallback_1_0 = std::function<void(V1_0::ErrorStatus status, std::string callingFunction)>; 21 22 struct ArmnnCallback_1_0 23 { 24 armnnExecuteCallback_1_0 callback; 25 }; 26 27 struct ExecutionContext_1_0 {}; 28 29 using CallbackContext_1_0 = CallbackContext<armnnExecuteCallback_1_0, ExecutionContext_1_0>; 30 31 template <typename HalVersion> 32 class ArmnnPreparedModel : public V1_0::IPreparedModel 33 { 34 public: 35 using HalModel = typename HalVersion::Model; 36 37 ArmnnPreparedModel(armnn::NetworkId networkId, 38 armnn::IRuntime* runtime, 39 const HalModel& model, 40 const std::string& requestInputsAndOutputsDumpDir, 41 const bool gpuProfilingEnabled); 42 43 virtual ~ArmnnPreparedModel(); 44 45 virtual Return<V1_0::ErrorStatus> execute(const V1_0::Request& request, 46 const ::android::sp<V1_0::IExecutionCallback>& callback) override; 47 48 /// execute the graph prepared from the request 49 void ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools, 50 armnn::InputTensors& inputTensors, 51 armnn::OutputTensors& outputTensors, 52 CallbackContext_1_0 callback); 53 54 /// Executes this model with dummy inputs (e.g. all zeroes). 55 /// \return false on failure, otherwise true 56 bool ExecuteWithDummyInputs(); 57 58 private: 59 template <typename TensorBindingCollection> 60 void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings); 61 62 armnn::NetworkId m_NetworkId; 63 armnn::IRuntime* m_Runtime; 64 HalModel m_Model; 65 // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads 66 // It is specific to this class, so it is declared as static here 67 static RequestThread<ArmnnPreparedModel, HalVersion, CallbackContext_1_0> m_RequestThread; 68 uint32_t m_RequestCount; 69 const std::string& m_RequestInputsAndOutputsDumpDir; 70 const bool m_GpuProfilingEnabled; 71 }; 72 73 } 74