1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "fullconnection_builder.h"
17
18 #include "frameworks/native/transform.h"
19 #include "frameworks/native/validation.h"
20
21 namespace OHOS {
22 namespace NeuralNetworkRuntime {
23 namespace Ops {
24 static constexpr int INPUT_WITH_AXIS = 2;
25 static constexpr int INPUT_WITHOUT_AXIS = 1;
26 static constexpr int OUTPUT_NUM = 1;
27 static constexpr int SCALAR_LENGTH = 1;
28 static const std::string OP_NAME = "FullConnection";
29
FullConnectionBuilder()30 FullConnectionBuilder::FullConnectionBuilder() {}
31
~FullConnectionBuilder()32 FullConnectionBuilder::~FullConnectionBuilder() {}
33
SetFullConnectionInput(const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)34 OH_NN_ReturnCode FullConnectionBuilder::SetFullConnectionInput(const std::vector<uint32_t>& inputsIndex,
35 const std::vector<uint32_t>& outputsIndex,
36 const std::vector<std::shared_ptr<NNTensor>>& allTensors)
37 {
38 if (outputsIndex.size() != OUTPUT_NUM) {
39 LOGE("[FullConnection] SetFullConnectionInput failed, the index of outputs don't equal to %d.", OUTPUT_NUM);
40 return OH_NN_INVALID_PARAMETER;
41 }
42 size_t allTensorsSize = allTensors.size();
43 for (auto index : inputsIndex) {
44 if (index >= allTensorsSize) {
45 LOGE("[FullConnection] SetFullConnectionInput failed, the index of inputs is out of range.");
46 return OH_NN_INVALID_PARAMETER;
47 }
48 }
49
50 m_inputsIndex = inputsIndex;
51 m_outputsIndex = outputsIndex;
52
53 return OH_NN_SUCCESS;
54 }
55
SetFullConnectionActivation(std::shared_ptr<NNTensor> tensor)56 OH_NN_ReturnCode FullConnectionBuilder::SetFullConnectionActivation(std::shared_ptr<NNTensor> tensor)
57 {
58 tensor->IdentifyOpParameter();
59 // Set Activation
60 if (tensor->GetElementCount() != SCALAR_LENGTH) {
61 LOGE("[FullConnection] SetFullConnectionActivation failed, the Activation shoule be a scalar");
62 return OH_NN_INVALID_PARAMETER;
63 }
64
65 if (tensor->GetDataType() != OH_NN_INT8) {
66 LOGE("[FullConnection] SetFullConnectionActivation failed, the Activation should have type OH_NN_INT8.");
67 return OH_NN_INVALID_PARAMETER;
68 }
69
70 void* buffer = tensor->GetBuffer();
71 if (buffer == nullptr) {
72 LOGE("[FullConnection] SetFullConnectionActivation GetBuffer return nullptr");
73 return OH_NN_INVALID_PARAMETER;
74 }
75
76 int8_t* pFuseData = static_cast<int8_t*>(tensor->GetBuffer());
77 if (!OHOS::NeuralNetworkRuntime::Validation::ValidateFuseType(static_cast<OH_NN_FuseType>(*pFuseData))) {
78 LOGE("[FullConnection] SetFullConnectionActivation failed, activation input is invalid.");
79 return OH_NN_INVALID_PARAMETER;
80 }
81 m_activationType = NNToMS::TransfromFusionType((OH_NN_FuseType)(*pFuseData));
82
83 return OH_NN_SUCCESS;
84 }
85
SetAxis(std::shared_ptr<NNTensor> tensor)86 OH_NN_ReturnCode FullConnectionBuilder::SetAxis(std::shared_ptr<NNTensor> tensor)
87 {
88 if (m_useAxis) {
89 tensor->IdentifyOpParameter();
90
91 if (tensor->GetElementCount() != SCALAR_LENGTH) {
92 LOGE("[FullConnection] SetFullConnectionActivation failed, the axis shoule be a scalar");
93 return OH_NN_INVALID_PARAMETER;
94 }
95
96 if (tensor->GetDataType() != OH_NN_INT64) {
97 LOGE("[FullConnection] SetFullConnectionActivation failed, the Axis should be type OH_NN_INT64.");
98 return OH_NN_INVALID_PARAMETER;
99 }
100
101 void* buffer = tensor->GetBuffer();
102 if (buffer == nullptr) {
103 LOGE("[FullConnection] SetAxis GetBuffer return nullptr");
104 return OH_NN_INVALID_PARAMETER;
105 }
106
107 m_axis = *static_cast<int64_t*>(buffer);
108 }
109 return OH_NN_SUCCESS;
110 }
111
112
Build(const std::vector<uint32_t> & paramsIndex,const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)113 OH_NN_ReturnCode FullConnectionBuilder::Build(const std::vector<uint32_t>& paramsIndex,
114 const std::vector<uint32_t>& inputsIndex,
115 const std::vector<uint32_t>& outputsIndex,
116 const std::vector<std::shared_ptr<NNTensor>>& allTensors)
117 {
118 if (m_isBuild) {
119 LOGE("[FullConnection] Build failed, operation has been build, cannot build again.");
120 return OH_NN_OPERATION_FORBIDDEN;
121 }
122
123 bool useAxis = false;
124 if (paramsIndex.size() == INPUT_WITH_AXIS) {
125 useAxis = true;
126 } else if (paramsIndex.size() != INPUT_WITHOUT_AXIS) {
127 LOGE("[FullConnection] Build failed, the index of inputs should equal to %d if axis used or %d if not.",
128 INPUT_WITH_AXIS, INPUT_WITHOUT_AXIS);
129 return OH_NN_INVALID_PARAMETER;
130 }
131
132 OH_NN_ReturnCode returnCode = SetFullConnectionInput(inputsIndex, outputsIndex, allTensors);
133 if (returnCode != OH_NN_SUCCESS) {
134 LOGE("[FullConnection] Build failed, SetFullConnectionInput failed.");
135 return returnCode;
136 }
137
138 // Set axis
139 m_useAxis = useAxis;
140 for (int i : paramsIndex) {
141 std::shared_ptr<NNTensor> tensor = allTensors[i]; // 参数 tensor
142 switch (tensor->GetType()) {
143 case OH_NN_FULL_CONNECTION_AXIS:
144 returnCode = SetAxis(tensor);
145 break;
146 case OH_NN_FULL_CONNECTION_ACTIVATIONTYPE:
147 returnCode = SetFullConnectionActivation(tensor);
148 break;
149 default:
150 LOGE("[FullConnection] Build failed, param invalid, type = %d.", tensor->GetType());
151 return OH_NN_INVALID_PARAMETER;
152 }
153 if (returnCode != OH_NN_SUCCESS) {
154 LOGE("[FullConnection] Build failed, passed invalid param.");
155 return returnCode;
156 }
157 }
158
159 // The quantization type of the first output determinies that of the operator.
160 SetQuantType(outputsIndex, allTensors);
161
162 m_isBuild = true;
163 m_name = OP_NAME;
164 return OH_NN_SUCCESS;
165 }
166
GetPrimitive()167 LiteGraphPrimitvePtr FullConnectionBuilder::GetPrimitive()
168 {
169 if (!m_isBuild) {
170 LOGE("[FullConnection] GetPrimitive failed, cannot get primitive before call build.");
171 return {nullptr, DestroyLiteGraphPrimitive};
172 }
173
174 void* primitive = mindspore::lite::MindIR_FullConnection_CreatePrimitive(m_hasBias, m_useAxis,
175 m_axis, m_activationType);
176 LiteGraphPrimitvePtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive) ;
177 return graphPrimitivePtr;
178 }
179
// Hooks FullConnectionBuilder up to the OH_NN_OPS_FULL_CONNECTION operation type via the REGISTER_OPS macro.
// NOTE(review): the macro is defined outside this file; presumably it registers a factory for this builder.
REGISTER_OPS(FullConnectionBuilder, OH_NN_OPS_FULL_CONNECTION);
} // namespace Ops
} // namespace NeuralNetworkRuntime
} // namespace OHOS