1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef ANDROID_ML_NN_COMMON_CPU_EXECUTOR_H
18 #define ANDROID_ML_NN_COMMON_CPU_EXECUTOR_H
19
#include "HalInterfaces.h"
#include "OperationsUtils.h"
#include "Utils.h"

#include <algorithm>
#include <cstring>
#include <vector>
26
27 namespace android {
28 namespace nn {
29
30 // Information we maintain about each operand during execution that
31 // may change during execution.
32 struct RunTimeOperandInfo {
33 // TODO Storing the type here is redundant, as it won't change during execution.
34 OperandType type;
35 // The type and dimensions of the operand. The dimensions can
36 // change at runtime. We include the type because it's useful
37 // to pass together with the dimension to the functions implementing
38 // the operators.
39 std::vector<uint32_t> dimensions;
40
41 float scale;
42 int32_t zeroPoint;
43 // Where the operand's data is stored. Check the corresponding
44 // location information in the model to figure out if this points
45 // to memory we have allocated for an temporary operand.
46 uint8_t* buffer;
47 // The length of the buffer.
48 uint32_t length;
49 // Whether this is a temporary variable, a model input, a constant, etc.
50 OperandLifeTime lifetime;
51 // Keeps track of how many operations have yet to make use
52 // of this temporary variable. When the count is decremented to 0,
53 // we free the buffer. For non-temporary variables, this count is
54 // always 0.
55 uint32_t numberOfUsesLeft;
56
shapeRunTimeOperandInfo57 Shape shape() const {
58 return Shape{.type = type, .dimensions = dimensions, .scale = scale, .offset = zeroPoint};
59 }
60 };
61
62 // Used to keep a pointer to each of the memory pools.
63 struct RunTimePoolInfo {
64 sp<IMemory> memory;
65 hidl_memory hidlMemory;
66 uint8_t* buffer;
67
68 bool set(const hidl_memory& hidlMemory);
69 bool update();
70 };
71
72 bool setRunTimePoolInfosFromHidlMemories(std::vector<RunTimePoolInfo>* poolInfos,
73 const hidl_vec<hidl_memory>& pools);
74
75 // This class is used to execute a model on the CPU.
76 class CpuExecutor {
77 public:
78 // Executes the model. The results will be stored at the locations
79 // specified in the constructor.
80 // The model must outlive the executor. We prevent it from being modified
81 // while this is executing.
82 int run(const Model& model, const Request& request,
83 const std::vector<RunTimePoolInfo>& modelPoolInfos,
84 const std::vector<RunTimePoolInfo>& requestPoolInfos);
85
86 private:
87 bool initializeRunTimeInfo(const std::vector<RunTimePoolInfo>& modelPoolInfos,
88 const std::vector<RunTimePoolInfo>& requestPoolInfos);
89 // Runs one operation of the graph.
90 int executeOperation(const Operation& entry);
91 // Decrement the usage count for the operands listed. Frees the memory
92 // allocated for any temporary variable with a count of zero.
93 void freeNoLongerUsedOperands(const std::vector<uint32_t>& inputs);
94
95 // The model and the request that we'll execute. Only valid while run()
96 // is being executed.
97 const Model* mModel = nullptr;
98 const Request* mRequest = nullptr;
99
100 // We're copying the list of all the dimensions from the model, as
101 // these may be modified when we run the operatins. Since we're
102 // making a full copy, the indexes used in the operand description
103 // stay valid.
104 // std::vector<uint32_t> mDimensions;
105 // Runtime information about all the operands.
106 std::vector<RunTimeOperandInfo> mOperands;
107 };
108
109 namespace {
110
111 template <typename T>
getScalarData(const RunTimeOperandInfo & info)112 T getScalarData(const RunTimeOperandInfo& info) {
113 // TODO: Check buffer is at least as long as size of data.
114 T* data = reinterpret_cast<T*>(info.buffer);
115 return data[0];
116 }
117
IsNullInput(const RunTimeOperandInfo * input)118 inline bool IsNullInput(const RunTimeOperandInfo *input) {
119 return input->lifetime == OperandLifeTime::NO_VALUE;
120 }
121
NumInputsWithValues(const Operation & operation,std::vector<RunTimeOperandInfo> & operands)122 inline int NumInputsWithValues(const Operation &operation,
123 std::vector<RunTimeOperandInfo> &operands) {
124 const std::vector<uint32_t> &inputs = operation.inputs;
125 return std::count_if(inputs.begin(), inputs.end(),
126 [&operands](uint32_t i) {
127 return !IsNullInput(&operands[i]);
128 });
129 }
130
NumOutputs(const Operation & operation)131 inline int NumOutputs(const Operation &operation) {
132 return operation.outputs.size();
133 }
134
NumDimensions(const RunTimeOperandInfo * operand)135 inline size_t NumDimensions(const RunTimeOperandInfo *operand) {
136 return operand->shape().dimensions.size();
137 }
138
SizeOfDimension(const RunTimeOperandInfo * operand,int i)139 inline uint32_t SizeOfDimension(const RunTimeOperandInfo *operand, int i) {
140 return operand->shape().dimensions[i];
141 }
142
GetInput(const Operation & operation,std::vector<RunTimeOperandInfo> & operands,int index)143 inline RunTimeOperandInfo *GetInput(const Operation &operation,
144 std::vector<RunTimeOperandInfo> &operands,
145 int index) {
146 return &operands[operation.inputs[index]];
147 }
148
GetOutput(const Operation & operation,std::vector<RunTimeOperandInfo> & operands,int index)149 inline RunTimeOperandInfo *GetOutput(const Operation &operation,
150 std::vector<RunTimeOperandInfo> &operands,
151 int index) {
152 return &operands[operation.outputs[index]];
153 }
154
155 } // anonymous namespace
156
157 } // namespace nn
158 } // namespace android
159
160 #endif // ANDROID_ML_NN_COMMON_CPU_EXECUTOR_H
161