1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include "ArmnnDriver.hpp" 9 #include "ArmnnDriverImpl.hpp" 10 #include "RequestThread.hpp" 11 #include "ModelToINetworkConverter.hpp" 12 13 #include <NeuralNetworks.h> 14 #include <armnn/ArmNN.hpp> 15 16 #include <string> 17 #include <vector> 18 19 namespace armnn_driver 20 { 21 22 using CallbackAsync_1_2 = std::function< 23 void(V1_0::ErrorStatus errorStatus, 24 std::vector<::android::hardware::neuralnetworks::V1_2::OutputShape> outputShapes, 25 const ::android::hardware::neuralnetworks::V1_2::Timing& timing, 26 std::string callingFunction)>; 27 28 struct ExecutionContext_1_2 29 { 30 ::android::hardware::neuralnetworks::V1_2::MeasureTiming measureTimings = 31 ::android::hardware::neuralnetworks::V1_2::MeasureTiming::NO; 32 TimePoint driverStart; 33 }; 34 35 using CallbackContext_1_2 = CallbackContext<CallbackAsync_1_2, ExecutionContext_1_2>; 36 37 template <typename HalVersion> 38 class ArmnnPreparedModel_1_2 : public V1_2::IPreparedModel 39 { 40 public: 41 using HalModel = typename V1_2::Model; 42 43 ArmnnPreparedModel_1_2(armnn::NetworkId networkId, 44 armnn::IRuntime* runtime, 45 const HalModel& model, 46 const std::string& requestInputsAndOutputsDumpDir, 47 const bool gpuProfilingEnabled); 48 49 virtual ~ArmnnPreparedModel_1_2(); 50 51 virtual Return<V1_0::ErrorStatus> execute(const V1_0::Request& request, 52 const sp<V1_0::IExecutionCallback>& callback) override; 53 54 virtual Return<V1_0::ErrorStatus> execute_1_2(const V1_0::Request& request, V1_2::MeasureTiming measure, 55 const sp<V1_2::IExecutionCallback>& callback) override; 56 57 virtual Return<void> executeSynchronously(const V1_0::Request &request, 58 V1_2::MeasureTiming measure, 59 V1_2::IPreparedModel::executeSynchronously_cb cb) override; 60 61 virtual Return<void> configureExecutionBurst( 62 const sp<V1_2::IBurstCallback>& callback, 63 const android::hardware::MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel, 64 const android::hardware::MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel, 65 configureExecutionBurst_cb cb) override; 66 67 /// execute the graph prepared from the request 68 template<typename CallbackContext> 69 bool ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools, 70 armnn::InputTensors& inputTensors, 71 armnn::OutputTensors& outputTensors, 72 CallbackContext callback); 73 74 /// Executes this model with dummy inputs (e.g. all zeroes). 75 /// \return false on failure, otherwise true 76 bool ExecuteWithDummyInputs(); 77 78 private: 79 Return<V1_0::ErrorStatus> Execute(const V1_0::Request& request, 80 V1_2::MeasureTiming measureTiming, 81 CallbackAsync_1_2 callback); 82 83 Return<V1_0::ErrorStatus> PrepareMemoryForInputs( 84 armnn::InputTensors& inputs, 85 const V1_0::Request& request, 86 const std::vector<android::nn::RunTimePoolInfo>& memPools); 87 88 Return<V1_0::ErrorStatus> PrepareMemoryForOutputs( 89 armnn::OutputTensors& outputs, 90 std::vector<V1_2::OutputShape> &outputShapes, 91 const V1_0::Request& request, 92 const std::vector<android::nn::RunTimePoolInfo>& memPools); 93 94 Return <V1_0::ErrorStatus> PrepareMemoryForIO( 95 armnn::InputTensors& inputs, 96 armnn::OutputTensors& outputs, 97 std::vector<android::nn::RunTimePoolInfo>& memPools, 98 const V1_0::Request& request, 99 CallbackAsync_1_2 callback); 100 101 template <typename TensorBindingCollection> 102 void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings); 103 104 armnn::NetworkId m_NetworkId; 105 armnn::IRuntime* m_Runtime; 106 V1_2::Model m_Model; 107 // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads 108 // It is specific to this class, so it is declared as static here 109 static RequestThread<ArmnnPreparedModel_1_2, 110 HalVersion, 111 CallbackContext_1_2> m_RequestThread; 112 uint32_t m_RequestCount; 113 const std::string& m_RequestInputsAndOutputsDumpDir; 114 const bool m_GpuProfilingEnabled; 115 }; 116 117 } 118