• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include "ArmnnDriver.hpp"
9 #include "ArmnnDriverImpl.hpp"
10 #include "RequestThread.hpp"
11 
12 #include <NeuralNetworks.h>
13 #include <armnn/ArmNN.hpp>
14 
15 #include <string>
16 #include <vector>
17 
18 namespace armnn_driver
19 {
20 using armnnExecuteCallback_1_0 = std::function<void(V1_0::ErrorStatus status, std::string callingFunction)>;
21 
22 struct ArmnnCallback_1_0
23 {
24     armnnExecuteCallback_1_0 callback;
25 };
26 
27 struct ExecutionContext_1_0 {};
28 
29 using CallbackContext_1_0 = CallbackContext<armnnExecuteCallback_1_0, ExecutionContext_1_0>;
30 
31 template <typename HalVersion>
32 class ArmnnPreparedModel : public V1_0::IPreparedModel
33 {
34 public:
35     using HalModel = typename HalVersion::Model;
36 
37     ArmnnPreparedModel(armnn::NetworkId networkId,
38                        armnn::IRuntime* runtime,
39                        const HalModel& model,
40                        const std::string& requestInputsAndOutputsDumpDir,
41                        const bool gpuProfilingEnabled);
42 
43     virtual ~ArmnnPreparedModel();
44 
45     virtual Return<V1_0::ErrorStatus> execute(const V1_0::Request& request,
46                                               const ::android::sp<V1_0::IExecutionCallback>& callback) override;
47 
48     /// execute the graph prepared from the request
49     void ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
50                       armnn::InputTensors& inputTensors,
51                       armnn::OutputTensors& outputTensors,
52                       CallbackContext_1_0 callback);
53 
54     /// Executes this model with dummy inputs (e.g. all zeroes).
55     /// \return false on failure, otherwise true
56     bool ExecuteWithDummyInputs();
57 
58 private:
59     template <typename TensorBindingCollection>
60     void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
61 
62     armnn::NetworkId                                                        m_NetworkId;
63     armnn::IRuntime*                                                        m_Runtime;
64     HalModel                                                                m_Model;
65     // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads
66     // It is specific to this class, so it is declared as static here
67     static RequestThread<ArmnnPreparedModel, HalVersion, CallbackContext_1_0> m_RequestThread;
68     uint32_t                                                                m_RequestCount;
69     const std::string&                                                      m_RequestInputsAndOutputsDumpDir;
70     const bool                                                              m_GpuProfilingEnabled;
71 };
72 
73 }
74