• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "fullconnection_builder.h"
17 
18 #include "transform.h"
19 #include "validation.h"
20 
21 namespace OHOS {
22 namespace NeuralNetworkRuntime {
23 namespace Ops {
namespace {
// Number of parameter tensors Build() accepts when the optional axis parameter is supplied.
constexpr int INPUT_WITH_AXIS = 2;
// Number of parameter tensors Build() accepts when no axis parameter is supplied.
constexpr int INPUT_WITHOUT_AXIS = 1;
// Exact number of output tensors the operator must have.
constexpr int OUTPUT_NUM = 1;
// Element count a parameter tensor must have to be accepted as a scalar.
constexpr int SCALAR_LENGTH = 1;
// Name stored in m_name once Build() succeeds.
const std::string OP_NAME = "FullConnection";
} // namespace
29 
FullConnectionBuilder()30 FullConnectionBuilder::FullConnectionBuilder() {}
31 
~FullConnectionBuilder()32 FullConnectionBuilder::~FullConnectionBuilder() {}
33 
SetFullConnectionInput(const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)34 OH_NN_ReturnCode FullConnectionBuilder::SetFullConnectionInput(const std::vector<uint32_t>& inputsIndex,
35                                                                const std::vector<uint32_t>& outputsIndex,
36                                                                const std::vector<std::shared_ptr<NNTensor>>& allTensors)
37 {
38     if (outputsIndex.size() != OUTPUT_NUM) {
39         LOGE("[FullConnection] SetFullConnectionInput failed, the index of outputs don't equal to %d.", OUTPUT_NUM);
40         return OH_NN_INVALID_PARAMETER;
41     }
42     size_t allTensorsSize = allTensors.size();
43     bool isOverTensorSize = std::any_of(inputsIndex.begin(), inputsIndex.end(), [allTensorsSize](uint32_t index) {
44         return index >= allTensorsSize;
45     });
46     if (isOverTensorSize) {
47         LOGE("[FullConnection] SetFullConnectionInput failed, the index of inputs is out of range.");
48         return OH_NN_INVALID_PARAMETER;
49     }
50 
51     m_inputsIndex = inputsIndex;
52     m_outputsIndex = outputsIndex;
53 
54     return OH_NN_SUCCESS;
55 }
56 
SetFullConnectionActivation(std::shared_ptr<NNTensor> tensor)57 OH_NN_ReturnCode FullConnectionBuilder::SetFullConnectionActivation(std::shared_ptr<NNTensor> tensor)
58 {
59     tensor->IdentifyOpParameter();
60     // Set Activation
61     if (tensor->GetElementCount() != SCALAR_LENGTH) {
62         LOGE("[FullConnection] SetFullConnectionActivation failed, the Activation shoule be a scalar");
63         return OH_NN_INVALID_PARAMETER;
64     }
65 
66     if (tensor->GetDataType() != OH_NN_INT8) {
67         LOGE("[FullConnection] SetFullConnectionActivation failed, the Activation should have type OH_NN_INT8.");
68         return OH_NN_INVALID_PARAMETER;
69     }
70 
71     void* buffer = tensor->GetBuffer();
72     if (buffer == nullptr) {
73         LOGE("[FullConnection] SetFullConnectionActivation GetBuffer return nullptr");
74         return OH_NN_INVALID_PARAMETER;
75     }
76 
77     int8_t* pFuseData = static_cast<int8_t*>(tensor->GetBuffer());
78     if (!OHOS::NeuralNetworkRuntime::Validation::ValidateFuseType(static_cast<OH_NN_FuseType>(*pFuseData))) {
79         LOGE("[FullConnection] SetFullConnectionActivation failed, activation input is invalid.");
80         return OH_NN_INVALID_PARAMETER;
81     }
82     m_activationType = NNToMS::TransfromFusionType((OH_NN_FuseType)(*pFuseData));
83 
84     return OH_NN_SUCCESS;
85 }
86 
SetAxis(std::shared_ptr<NNTensor> tensor)87 OH_NN_ReturnCode FullConnectionBuilder::SetAxis(std::shared_ptr<NNTensor> tensor)
88 {
89     if (m_useAxis) {
90         tensor->IdentifyOpParameter();
91 
92         if (tensor->GetElementCount() != SCALAR_LENGTH) {
93             LOGE("[FullConnection] SetFullConnectionActivation failed, the axis shoule be a scalar");
94             return OH_NN_INVALID_PARAMETER;
95         }
96 
97         if (tensor->GetDataType() != OH_NN_INT64) {
98             LOGE("[FullConnection] SetFullConnectionActivation failed, the Axis should be type OH_NN_INT64.");
99             return OH_NN_INVALID_PARAMETER;
100         }
101 
102         void* buffer = tensor->GetBuffer();
103         if (buffer == nullptr) {
104             LOGE("[FullConnection] SetAxis GetBuffer return nullptr");
105             return OH_NN_INVALID_PARAMETER;
106         }
107 
108         m_axis = *static_cast<int64_t*>(buffer);
109     }
110     return OH_NN_SUCCESS;
111 }
112 
113 
Build(const std::vector<uint32_t> & paramsIndex,const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)114 OH_NN_ReturnCode FullConnectionBuilder::Build(const std::vector<uint32_t>& paramsIndex,
115                                               const std::vector<uint32_t>& inputsIndex,
116                                               const std::vector<uint32_t>& outputsIndex,
117                                               const std::vector<std::shared_ptr<NNTensor>>& allTensors)
118 {
119     if (m_isBuild) {
120         LOGE("[FullConnection] Build failed, operation has been build, cannot build again.");
121         return OH_NN_OPERATION_FORBIDDEN;
122     }
123 
124     bool useAxis = false;
125     if (paramsIndex.size() == INPUT_WITH_AXIS) {
126         useAxis = true;
127     } else if (paramsIndex.size() != INPUT_WITHOUT_AXIS) {
128         LOGE("[FullConnection] Build failed, the index of inputs should equal to %d if axis used or %d if not.",
129             INPUT_WITH_AXIS, INPUT_WITHOUT_AXIS);
130         return OH_NN_INVALID_PARAMETER;
131     }
132 
133     OH_NN_ReturnCode returnCode = SetFullConnectionInput(inputsIndex, outputsIndex, allTensors);
134     if (returnCode != OH_NN_SUCCESS) {
135         LOGE("[FullConnection] Build failed, SetFullConnectionInput failed.");
136         return returnCode;
137     }
138 
139     // Set axis
140     m_useAxis = useAxis;
141     for (int i : paramsIndex) {
142         std::shared_ptr<NNTensor> tensor = allTensors[i]; // 参数 tensor
143         switch (tensor->GetType()) {
144             case OH_NN_FULL_CONNECTION_AXIS:
145                 returnCode = SetAxis(tensor);
146                 break;
147             case OH_NN_FULL_CONNECTION_ACTIVATIONTYPE:
148                 returnCode = SetFullConnectionActivation(tensor);
149                 break;
150             default:
151                 LOGE("[FullConnection] Build failed, param invalid, type = %d.", tensor->GetType());
152                 return OH_NN_INVALID_PARAMETER;
153         }
154         if (returnCode != OH_NN_SUCCESS) {
155             LOGE("[FullConnection] Build failed, passed invalid param.");
156             return returnCode;
157         }
158     }
159 
160     // The quantization type of the first output determinies that of the operator.
161     SetQuantType(outputsIndex, allTensors);
162 
163     m_isBuild = true;
164     m_name = OP_NAME;
165     return OH_NN_SUCCESS;
166 }
167 
GetPrimitive()168 LiteGraphPrimitvePtr FullConnectionBuilder::GetPrimitive()
169 {
170     if (!m_isBuild) {
171         LOGE("[FullConnection] GetPrimitive failed, cannot get primitive before call build.");
172         return {nullptr, DestroyLiteGraphPrimitive};
173     }
174 
175     void* primitive = mindspore::lite::MindIR_FullConnection_CreatePrimitive(m_hasBias, m_useAxis,
176         m_axis, m_activationType);
177     LiteGraphPrimitvePtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive) ;
178     return graphPrimitivePtr;
179 }
180 
// Associates OH_NN_OPS_FULL_CONNECTION with FullConnectionBuilder via the REGISTER_OPS macro
// (defined elsewhere; presumably adds the builder to the ops registry — confirm).
REGISTER_OPS(FullConnectionBuilder, OH_NN_OPS_FULL_CONNECTION);
182 } // namespace Ops
183 } // namespace NeuralNetworkRuntime
184 } // namespace OHOS