• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_ML_NN_COMMON_CPU_EXECUTOR_H
18 #define ANDROID_ML_NN_COMMON_CPU_EXECUTOR_H
19 
#include "HalInterfaces.h"
#include "OperationsUtils.h"
#include "Utils.h"

#include <algorithm>
#include <cstring>
#include <type_traits>
#include <vector>
26 
27 namespace android {
28 namespace nn {
29 
// Information we maintain about each operand during execution that
// may change during execution.
struct RunTimeOperandInfo {
    // The type of the operand. It is kept next to the dimensions because it is
    // useful to pass the two together to the functions implementing the
    // operators (see shape()).
    // TODO Storing the type here is redundant, as it won't change during execution.
    OperandType type;
    // The dimensions of the operand.  The dimensions can change at runtime.
    std::vector<uint32_t> dimensions;

    // Scale and zero point of the operand; shape() forwards them as
    // Shape::scale and Shape::offset respectively. Presumably only meaningful
    // for quantized operand types — confirm against the operand type definitions.
    float scale;
    int32_t zeroPoint;
    // Where the operand's data is stored.  Check the corresponding
    // location information in the model to figure out if this points
    // to memory we have allocated for a temporary operand.
    uint8_t* buffer;
    // The length of the buffer (presumably in bytes — confirm at the call sites).
    uint32_t length;
    // Whether this is a temporary variable, a model input, a constant, etc.
    // OperandLifeTime::NO_VALUE is treated as an absent input by IsNullInput().
    OperandLifeTime lifetime;
    // Keeps track of how many operations have yet to make use
    // of this temporary variable.  When the count is decremented to 0,
    // we free the buffer.  For non-temporary variables, this count is
    // always 0.
    uint32_t numberOfUsesLeft;

    // Packs the type, dimensions, scale and zero point into a Shape.
    // Returns by value, so the dimensions vector is copied on every call;
    // prefer reading `dimensions` directly when only the dimensions are needed.
    Shape shape() const {
        return Shape{.type = type, .dimensions = dimensions, .scale = scale, .offset = zeroPoint};
    }
};
61 
// Used to keep a pointer to each of the memory pools.
// NOTE(review): `buffer` has no default member initializer, so it is
// indeterminate until something assigns it (presumably set()) — confirm.
struct RunTimePoolInfo {
    // Reference-counted handle to the pool's memory object.
    sp<IMemory> memory;
    // The HIDL memory descriptor this entry was created from.
    hidl_memory hidlMemory;
    // Raw pointer to the start of the pool's data; presumably obtained by
    // mapping hidlMemory — confirm in the implementation.
    uint8_t* buffer;

    // Initializes this entry from `hidlMemory`. Returns a success flag
    // (presumably false if the memory cannot be mapped — confirm).
    bool set(const hidl_memory& hidlMemory);
    // Returns a success flag. What it refreshes is defined in the
    // implementation, which is not visible from this header.
    bool update();
};
71 
// Builds the per-pool runtime info for a list of HIDL memory pools, writing
// the result into *poolInfos. Returns a success flag.
// NOTE(review): presumably resizes *poolInfos to pools.size() and calls
// RunTimePoolInfo::set() on each entry, failing if any set() fails — confirm
// against the implementation.
bool setRunTimePoolInfosFromHidlMemories(std::vector<RunTimePoolInfo>* poolInfos,
                                         const hidl_vec<hidl_memory>& pools);
74 
// This class is used to execute a model on the CPU.
class CpuExecutor {
public:
    // Executes `model` for the given `request`. The executor stores pointers to
    // `model` and `request` (mModel / mRequest) that are only valid while run()
    // is executing, so both must stay alive until run() returns. They are taken
    // by const reference, so run() itself does not modify them.
    // NOTE(review): results are presumably written to the output locations the
    // request describes, backed by `requestPoolInfos`; the returned int is a
    // status code defined by the implementation — confirm.
    int run(const Model& model, const Request& request,
            const std::vector<RunTimePoolInfo>& modelPoolInfos,
            const std::vector<RunTimePoolInfo>& requestPoolInfos);

private:
    // Sets up the per-operand runtime state (mOperands) from the model and the
    // two pool lists. Returns a success flag (presumably false on any invalid
    // operand location — confirm).
    bool initializeRunTimeInfo(const std::vector<RunTimePoolInfo>& modelPoolInfos,
                               const std::vector<RunTimePoolInfo>& requestPoolInfos);
    // Runs one operation of the graph.
    int executeOperation(const Operation& entry);
    // Decrement the usage count for the operands listed.  Frees the memory
    // allocated for any temporary variable with a count of zero.
    void freeNoLongerUsedOperands(const std::vector<uint32_t>& inputs);

    // The model and the request that we'll execute. Only valid while run()
    // is being executed.
    const Model* mModel = nullptr;
    const Request* mRequest = nullptr;

    // Runtime information about all the operands. Each entry holds its own copy
    // of the operand's dimensions, which may be modified while the operations
    // run. Presumably indexed like the model's operand list, so the operand
    // indexes in each Operation stay valid — confirm in initializeRunTimeInfo().
    std::vector<RunTimeOperandInfo> mOperands;
};
108 
109 namespace {
110 
111 template <typename T>
getScalarData(const RunTimeOperandInfo & info)112 T getScalarData(const RunTimeOperandInfo& info) {
113   // TODO: Check buffer is at least as long as size of data.
114   T* data = reinterpret_cast<T*>(info.buffer);
115   return data[0];
116 }
117 
IsNullInput(const RunTimeOperandInfo * input)118 inline bool IsNullInput(const RunTimeOperandInfo *input) {
119     return input->lifetime == OperandLifeTime::NO_VALUE;
120 }
121 
NumInputsWithValues(const Operation & operation,std::vector<RunTimeOperandInfo> & operands)122 inline int NumInputsWithValues(const Operation &operation,
123                                std::vector<RunTimeOperandInfo> &operands) {
124   const std::vector<uint32_t> &inputs = operation.inputs;
125   return std::count_if(inputs.begin(), inputs.end(),
126                        [&operands](uint32_t i) {
127                          return !IsNullInput(&operands[i]);
128                        });
129 }
130 
NumOutputs(const Operation & operation)131 inline int NumOutputs(const Operation &operation) {
132   return operation.outputs.size();
133 }
134 
NumDimensions(const RunTimeOperandInfo * operand)135 inline size_t NumDimensions(const RunTimeOperandInfo *operand) {
136   return operand->shape().dimensions.size();
137 }
138 
SizeOfDimension(const RunTimeOperandInfo * operand,int i)139 inline uint32_t SizeOfDimension(const RunTimeOperandInfo *operand, int i) {
140   return operand->shape().dimensions[i];
141 }
142 
GetInput(const Operation & operation,std::vector<RunTimeOperandInfo> & operands,int index)143 inline RunTimeOperandInfo *GetInput(const Operation &operation,
144                                     std::vector<RunTimeOperandInfo> &operands,
145                                     int index) {
146   return &operands[operation.inputs[index]];
147 }
148 
GetOutput(const Operation & operation,std::vector<RunTimeOperandInfo> & operands,int index)149 inline RunTimeOperandInfo *GetOutput(const Operation &operation,
150                                      std::vector<RunTimeOperandInfo> &operands,
151                                      int index) {
152   return &operands[operation.outputs[index]];
153 }
154 
155 }  // anonymous namespace
156 
157 } // namespace nn
158 } // namespace android
159 
160 #endif // ANDROID_ML_NN_COMMON_CPU_EXECUTOR_H
161