1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "fullconnection_builder.h"
17
18 #include "transform.h"
19 #include "validation.h"
20
21 namespace OHOS {
22 namespace NeuralNetworkRuntime {
23 namespace Ops {
namespace {
// Number of parameter tensors Build() accepts when an axis parameter is supplied.
constexpr int INPUT_WITH_AXIS = 2;
// Number of parameter tensors Build() accepts when no axis parameter is supplied.
constexpr int INPUT_WITHOUT_AXIS = 1;
// Required number of output indices.
constexpr int OUTPUT_NUM = 1;
// Element count a scalar parameter tensor must have.
constexpr int SCALAR_LENGTH = 1;
// Name assigned to m_name once Build() succeeds.
const std::string OP_NAME = "FullConnection";
} // namespace
29
FullConnectionBuilder()30 FullConnectionBuilder::FullConnectionBuilder() {}
31
~FullConnectionBuilder()32 FullConnectionBuilder::~FullConnectionBuilder() {}
33
SetFullConnectionInput(const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)34 OH_NN_ReturnCode FullConnectionBuilder::SetFullConnectionInput(const std::vector<uint32_t>& inputsIndex,
35 const std::vector<uint32_t>& outputsIndex,
36 const std::vector<std::shared_ptr<NNTensor>>& allTensors)
37 {
38 if (outputsIndex.size() != OUTPUT_NUM) {
39 LOGE("[FullConnection] SetFullConnectionInput failed, the index of outputs don't equal to %d.", OUTPUT_NUM);
40 return OH_NN_INVALID_PARAMETER;
41 }
42 size_t allTensorsSize = allTensors.size();
43 bool isOverTensorSize = std::any_of(inputsIndex.begin(), inputsIndex.end(), [allTensorsSize](uint32_t index) {
44 return index >= allTensorsSize;
45 });
46 if (isOverTensorSize) {
47 LOGE("[FullConnection] SetFullConnectionInput failed, the index of inputs is out of range.");
48 return OH_NN_INVALID_PARAMETER;
49 }
50
51 m_inputsIndex = inputsIndex;
52 m_outputsIndex = outputsIndex;
53
54 return OH_NN_SUCCESS;
55 }
56
SetFullConnectionActivation(std::shared_ptr<NNTensor> tensor)57 OH_NN_ReturnCode FullConnectionBuilder::SetFullConnectionActivation(std::shared_ptr<NNTensor> tensor)
58 {
59 tensor->IdentifyOpParameter();
60 // Set Activation
61 if (tensor->GetElementCount() != SCALAR_LENGTH) {
62 LOGE("[FullConnection] SetFullConnectionActivation failed, the Activation shoule be a scalar");
63 return OH_NN_INVALID_PARAMETER;
64 }
65
66 if (tensor->GetDataType() != OH_NN_INT8) {
67 LOGE("[FullConnection] SetFullConnectionActivation failed, the Activation should have type OH_NN_INT8.");
68 return OH_NN_INVALID_PARAMETER;
69 }
70
71 void* buffer = tensor->GetBuffer();
72 if (buffer == nullptr) {
73 LOGE("[FullConnection] SetFullConnectionActivation GetBuffer return nullptr");
74 return OH_NN_INVALID_PARAMETER;
75 }
76
77 int8_t* pFuseData = static_cast<int8_t*>(tensor->GetBuffer());
78 if (!OHOS::NeuralNetworkRuntime::Validation::ValidateFuseType(static_cast<OH_NN_FuseType>(*pFuseData))) {
79 LOGE("[FullConnection] SetFullConnectionActivation failed, activation input is invalid.");
80 return OH_NN_INVALID_PARAMETER;
81 }
82 m_activationType = NNToMS::TransfromFusionType((OH_NN_FuseType)(*pFuseData));
83
84 return OH_NN_SUCCESS;
85 }
86
SetAxis(std::shared_ptr<NNTensor> tensor)87 OH_NN_ReturnCode FullConnectionBuilder::SetAxis(std::shared_ptr<NNTensor> tensor)
88 {
89 if (m_useAxis) {
90 tensor->IdentifyOpParameter();
91
92 if (tensor->GetElementCount() != SCALAR_LENGTH) {
93 LOGE("[FullConnection] SetFullConnectionActivation failed, the axis shoule be a scalar");
94 return OH_NN_INVALID_PARAMETER;
95 }
96
97 if (tensor->GetDataType() != OH_NN_INT64) {
98 LOGE("[FullConnection] SetFullConnectionActivation failed, the Axis should be type OH_NN_INT64.");
99 return OH_NN_INVALID_PARAMETER;
100 }
101
102 void* buffer = tensor->GetBuffer();
103 if (buffer == nullptr) {
104 LOGE("[FullConnection] SetAxis GetBuffer return nullptr");
105 return OH_NN_INVALID_PARAMETER;
106 }
107
108 m_axis = *static_cast<int64_t*>(buffer);
109 }
110 return OH_NN_SUCCESS;
111 }
112
113
Build(const std::vector<uint32_t> & paramsIndex,const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)114 OH_NN_ReturnCode FullConnectionBuilder::Build(const std::vector<uint32_t>& paramsIndex,
115 const std::vector<uint32_t>& inputsIndex,
116 const std::vector<uint32_t>& outputsIndex,
117 const std::vector<std::shared_ptr<NNTensor>>& allTensors)
118 {
119 if (m_isBuild) {
120 LOGE("[FullConnection] Build failed, operation has been build, cannot build again.");
121 return OH_NN_OPERATION_FORBIDDEN;
122 }
123
124 bool useAxis = false;
125 if (paramsIndex.size() == INPUT_WITH_AXIS) {
126 useAxis = true;
127 } else if (paramsIndex.size() != INPUT_WITHOUT_AXIS) {
128 LOGE("[FullConnection] Build failed, the index of inputs should equal to %d if axis used or %d if not.",
129 INPUT_WITH_AXIS, INPUT_WITHOUT_AXIS);
130 return OH_NN_INVALID_PARAMETER;
131 }
132
133 OH_NN_ReturnCode returnCode = SetFullConnectionInput(inputsIndex, outputsIndex, allTensors);
134 if (returnCode != OH_NN_SUCCESS) {
135 LOGE("[FullConnection] Build failed, SetFullConnectionInput failed.");
136 return returnCode;
137 }
138
139 // Set axis
140 m_useAxis = useAxis;
141 for (int i : paramsIndex) {
142 std::shared_ptr<NNTensor> tensor = allTensors[i]; // 参数 tensor
143 switch (tensor->GetType()) {
144 case OH_NN_FULL_CONNECTION_AXIS:
145 returnCode = SetAxis(tensor);
146 break;
147 case OH_NN_FULL_CONNECTION_ACTIVATIONTYPE:
148 returnCode = SetFullConnectionActivation(tensor);
149 break;
150 default:
151 LOGE("[FullConnection] Build failed, param invalid, type = %d.", tensor->GetType());
152 return OH_NN_INVALID_PARAMETER;
153 }
154 if (returnCode != OH_NN_SUCCESS) {
155 LOGE("[FullConnection] Build failed, passed invalid param.");
156 return returnCode;
157 }
158 }
159
160 // The quantization type of the first output determinies that of the operator.
161 SetQuantType(outputsIndex, allTensors);
162
163 m_isBuild = true;
164 m_name = OP_NAME;
165 return OH_NN_SUCCESS;
166 }
167
GetPrimitive()168 LiteGraphPrimitvePtr FullConnectionBuilder::GetPrimitive()
169 {
170 if (!m_isBuild) {
171 LOGE("[FullConnection] GetPrimitive failed, cannot get primitive before call build.");
172 return {nullptr, DestroyLiteGraphPrimitive};
173 }
174
175 void* primitive = mindspore::lite::MindIR_FullConnection_CreatePrimitive(m_hasBias, m_useAxis,
176 m_axis, m_activationType);
177 LiteGraphPrimitvePtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive) ;
178 return graphPrimitivePtr;
179 }
180
// Associates FullConnectionBuilder with OH_NN_OPS_FULL_CONNECTION via the REGISTER_OPS macro.
// NOTE(review): REGISTER_OPS is defined outside this file; presumably it adds the builder to an
// operation-type registry — confirm against the macro definition.
REGISTER_OPS(FullConnectionBuilder, OH_NN_OPS_FULL_CONNECTION);
182 } // namespace Ops
183 } // namespace NeuralNetworkRuntime
184 } // namespace OHOS