• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "fullconnection_builder.h"
17 
18 #include "frameworks/native/transform.h"
19 #include "frameworks/native/validation.h"
20 
21 namespace OHOS {
22 namespace NeuralNetworkRuntime {
23 namespace Ops {
namespace {
// Size of paramsIndex that Build() accepts when an axis parameter is supplied
// (Build() then sets m_useAxis to true).
constexpr int INPUT_WITH_AXIS{2};
// Size of paramsIndex that Build() accepts when no axis parameter is supplied.
constexpr int INPUT_WITHOUT_AXIS{1};
// Exact number of output indices required by SetFullConnectionInput().
constexpr int OUTPUT_NUM{1};
// Element count required of the axis and activation parameter tensors.
constexpr int SCALAR_LENGTH{1};
// Name assigned to m_name once Build() succeeds.
const std::string OP_NAME{"FullConnection"};
} // namespace
29 
FullConnectionBuilder()30 FullConnectionBuilder::FullConnectionBuilder() {}
31 
~FullConnectionBuilder()32 FullConnectionBuilder::~FullConnectionBuilder() {}
33 
SetFullConnectionInput(const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)34 OH_NN_ReturnCode FullConnectionBuilder::SetFullConnectionInput(const std::vector<uint32_t>& inputsIndex,
35                                                                const std::vector<uint32_t>& outputsIndex,
36                                                                const std::vector<std::shared_ptr<NNTensor>>& allTensors)
37 {
38     if (outputsIndex.size() != OUTPUT_NUM) {
39         LOGE("[FullConnection] SetFullConnectionInput failed, the index of outputs don't equal to %d.", OUTPUT_NUM);
40         return OH_NN_INVALID_PARAMETER;
41     }
42     size_t allTensorsSize = allTensors.size();
43     for (auto index : inputsIndex) {
44         if (index >= allTensorsSize) {
45             LOGE("[FullConnection] SetFullConnectionInput failed, the index of inputs is out of range.");
46             return OH_NN_INVALID_PARAMETER;
47         }
48     }
49 
50     m_inputsIndex = inputsIndex;
51     m_outputsIndex = outputsIndex;
52 
53     return OH_NN_SUCCESS;
54 }
55 
SetFullConnectionActivation(std::shared_ptr<NNTensor> tensor)56 OH_NN_ReturnCode FullConnectionBuilder::SetFullConnectionActivation(std::shared_ptr<NNTensor> tensor)
57 {
58     tensor->IdentifyOpParameter();
59     // Set Activation
60     if (tensor->GetElementCount() != SCALAR_LENGTH) {
61         LOGE("[FullConnection] SetFullConnectionActivation failed, the Activation shoule be a scalar");
62         return OH_NN_INVALID_PARAMETER;
63     }
64 
65     if (tensor->GetDataType() != OH_NN_INT8) {
66         LOGE("[FullConnection] SetFullConnectionActivation failed, the Activation should have type OH_NN_INT8.");
67         return OH_NN_INVALID_PARAMETER;
68     }
69 
70     void* buffer = tensor->GetBuffer();
71     if (buffer == nullptr) {
72         LOGE("[FullConnection] SetFullConnectionActivation GetBuffer return nullptr");
73         return OH_NN_INVALID_PARAMETER;
74     }
75 
76     int8_t* pFuseData = static_cast<int8_t*>(tensor->GetBuffer());
77     if (!OHOS::NeuralNetworkRuntime::Validation::ValidateFuseType(static_cast<OH_NN_FuseType>(*pFuseData))) {
78         LOGE("[FullConnection] SetFullConnectionActivation failed, activation input is invalid.");
79         return OH_NN_INVALID_PARAMETER;
80     }
81     m_activationType = NNToMS::TransfromFusionType((OH_NN_FuseType)(*pFuseData));
82 
83     return OH_NN_SUCCESS;
84 }
85 
SetAxis(std::shared_ptr<NNTensor> tensor)86 OH_NN_ReturnCode FullConnectionBuilder::SetAxis(std::shared_ptr<NNTensor> tensor)
87 {
88     if (m_useAxis) {
89         tensor->IdentifyOpParameter();
90 
91         if (tensor->GetElementCount() != SCALAR_LENGTH) {
92             LOGE("[FullConnection] SetFullConnectionActivation failed, the axis shoule be a scalar");
93             return OH_NN_INVALID_PARAMETER;
94         }
95 
96         if (tensor->GetDataType() != OH_NN_INT64) {
97             LOGE("[FullConnection] SetFullConnectionActivation failed, the Axis should be type OH_NN_INT64.");
98             return OH_NN_INVALID_PARAMETER;
99         }
100 
101         void* buffer = tensor->GetBuffer();
102         if (buffer == nullptr) {
103             LOGE("[FullConnection] SetAxis GetBuffer return nullptr");
104             return OH_NN_INVALID_PARAMETER;
105         }
106 
107         m_axis = *static_cast<int64_t*>(buffer);
108     }
109     return OH_NN_SUCCESS;
110 }
111 
112 
Build(const std::vector<uint32_t> & paramsIndex,const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)113 OH_NN_ReturnCode FullConnectionBuilder::Build(const std::vector<uint32_t>& paramsIndex,
114                                               const std::vector<uint32_t>& inputsIndex,
115                                               const std::vector<uint32_t>& outputsIndex,
116                                               const std::vector<std::shared_ptr<NNTensor>>& allTensors)
117 {
118     if (m_isBuild) {
119         LOGE("[FullConnection] Build failed, operation has been build, cannot build again.");
120         return OH_NN_OPERATION_FORBIDDEN;
121     }
122 
123     bool useAxis = false;
124     if (paramsIndex.size() == INPUT_WITH_AXIS) {
125         useAxis = true;
126     } else if (paramsIndex.size() != INPUT_WITHOUT_AXIS) {
127         LOGE("[FullConnection] Build failed, the index of inputs should equal to %d if axis used or %d if not.",
128             INPUT_WITH_AXIS, INPUT_WITHOUT_AXIS);
129         return OH_NN_INVALID_PARAMETER;
130     }
131 
132     OH_NN_ReturnCode returnCode = SetFullConnectionInput(inputsIndex, outputsIndex, allTensors);
133     if (returnCode != OH_NN_SUCCESS) {
134         LOGE("[FullConnection] Build failed, SetFullConnectionInput failed.");
135         return returnCode;
136     }
137 
138     // Set axis
139     m_useAxis = useAxis;
140     for (int i : paramsIndex) {
141         std::shared_ptr<NNTensor> tensor = allTensors[i]; // 参数 tensor
142         switch (tensor->GetType()) {
143             case OH_NN_FULL_CONNECTION_AXIS:
144                 returnCode = SetAxis(tensor);
145                 break;
146             case OH_NN_FULL_CONNECTION_ACTIVATIONTYPE:
147                 returnCode = SetFullConnectionActivation(tensor);
148                 break;
149             default:
150                 LOGE("[FullConnection] Build failed, param invalid, type = %d.", tensor->GetType());
151                 return OH_NN_INVALID_PARAMETER;
152         }
153         if (returnCode != OH_NN_SUCCESS) {
154             LOGE("[FullConnection] Build failed, passed invalid param.");
155             return returnCode;
156         }
157     }
158 
159     // The quantization type of the first output determinies that of the operator.
160     SetQuantType(outputsIndex, allTensors);
161 
162     m_isBuild = true;
163     m_name = OP_NAME;
164     return OH_NN_SUCCESS;
165 }
166 
GetPrimitive()167 LiteGraphPrimitvePtr FullConnectionBuilder::GetPrimitive()
168 {
169     if (!m_isBuild) {
170         LOGE("[FullConnection] GetPrimitive failed, cannot get primitive before call build.");
171         return {nullptr, DestroyLiteGraphPrimitive};
172     }
173 
174     void* primitive = mindspore::lite::MindIR_FullConnection_CreatePrimitive(m_hasBias, m_useAxis,
175         m_axis, m_activationType);
176     LiteGraphPrimitvePtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive) ;
177     return graphPrimitivePtr;
178 }
179 
// Invokes the REGISTER_OPS macro (defined outside this file) with FullConnectionBuilder and
// OH_NN_OPS_FULL_CONNECTION. NOTE(review): presumably this associates the builder with that
// operation type in an ops registry; confirm against the macro definition.
180 REGISTER_OPS(FullConnectionBuilder, OH_NN_OPS_FULL_CONNECTION);
181 } // namespace Ops
182 } // namespace NeuralNetworkRuntime
183 } // namespace OHOS