1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "split_builder.h"
17
18 namespace OHOS {
19 namespace NeuralNetworkRuntime {
20 namespace Ops {
// Split consumes exactly one input tensor (the tensor to be split).
static constexpr int INPUT_NUM = 1;
// Operator name recorded into m_name once Build() succeeds.
static const std::string OP_NAME = "Split";
23
SplitBuilder()24 SplitBuilder::SplitBuilder() {}
25
~SplitBuilder()26 SplitBuilder::~SplitBuilder() {}
27
SetInputAndOutput(const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)28 OH_NN_ReturnCode SplitBuilder::SetInputAndOutput(const std::vector<uint32_t> &inputsIndex,
29 const std::vector<uint32_t> &outputsIndex, const std::vector<std::shared_ptr<NNTensor>> &allTensors)
30 {
31 auto inputSize = inputsIndex.size();
32 if (inputSize != INPUT_NUM) {
33 LOGE("[SplitBuilder] The number of inputsIndex should be %d, its number is %zu.", INPUT_NUM, inputSize);
34 return OH_NN_INVALID_PARAMETER;
35 }
36
37 auto allTensorSize = allTensors.size();
38 bool isOverTensorSize = std::any_of(inputsIndex.begin(), inputsIndex.end(), [allTensorSize](uint32_t index) {
39 return index >= allTensorSize;
40 });
41 if (isOverTensorSize) {
42 LOGE("[SplitBuilder] InputsIndex of Split is out of range.");
43 return OH_NN_INVALID_PARAMETER;
44 }
45
46 isOverTensorSize = std::any_of(outputsIndex.begin(), outputsIndex.end(), [allTensorSize](uint32_t index) {
47 return index >= allTensorSize;
48 });
49 if (isOverTensorSize) {
50 LOGE("[SplitBuilder] InputsIndex of Split is out of range.");
51 return OH_NN_INVALID_PARAMETER;
52 }
53
54 m_inputsIndex = inputsIndex;
55 m_outputsIndex = outputsIndex;
56
57 // The quantization type of the first output determinies that of the operator.
58 SetQuantType(outputsIndex, allTensors);
59
60 return OH_NN_SUCCESS;
61 }
62
SetAxis(std::shared_ptr<NNTensor> tensor)63 OH_NN_ReturnCode SplitBuilder::SetAxis(std::shared_ptr<NNTensor> tensor)
64 {
65 if (tensor->GetDataType() != OH_NN_INT64) {
66 LOGE("[SplitBuilder] The 4th input axis should be type OH_NN_INT64.");
67 return OH_NN_INVALID_PARAMETER;
68 }
69
70 if (tensor->GetElementCount() != 1) {
71 LOGE("[SplitBuilder] The 4th input axis should be scaler.");
72 return OH_NN_INVALID_PARAMETER;
73 }
74
75 void* buffer = tensor->GetBuffer();
76 if (buffer == nullptr) {
77 LOGE("[SplitBuilder] Tensor buffer is nullptr.");
78 return OH_NN_INVALID_PARAMETER;
79 }
80 m_axis = *(static_cast<const int64_t *>(buffer));
81
82 return OH_NN_SUCCESS;
83 }
84
SetOutputNum(std::shared_ptr<NNTensor> tensor)85 OH_NN_ReturnCode SplitBuilder::SetOutputNum(std::shared_ptr<NNTensor> tensor)
86 {
87 if (tensor->GetDataType() != OH_NN_INT64) {
88 LOGE("[SplitBuilder] The 2nd input outputNum should be type OH_NN_INT64.");
89 return OH_NN_INVALID_PARAMETER;
90 }
91
92 if (tensor->GetElementCount() != 1) {
93 LOGE("[SoftmaxBuilder] The 2nd input outputNum should be scaler.");
94 return OH_NN_INVALID_PARAMETER;
95 }
96
97 m_output_num = *(static_cast<const int64_t *>(tensor->GetBuffer()));
98
99 return OH_NN_SUCCESS;
100 }
101
SetSizeSplits(std::shared_ptr<NNTensor> tensor)102 OH_NN_ReturnCode SplitBuilder::SetSizeSplits(std::shared_ptr<NNTensor> tensor)
103 {
104 if (tensor->GetDataType() != OH_NN_INT64) {
105 LOGE("[SplitBuilder] The 3rd input sizeSplit should be type OH_NN_INT64.");
106 return OH_NN_INVALID_PARAMETER;
107 }
108
109 const int64_t *size_splits_data_ptr = reinterpret_cast<const int64_t *>(tensor->GetBuffer());
110 for (uint32_t i = 0; i < tensor->GetElementCount(); i++) {
111 m_size_splits.push_back(*size_splits_data_ptr++);
112 }
113
114 return OH_NN_SUCCESS;
115 }
116
117 /**
118 * Build method.
119 * 1.set attr of ops.
120 * 2.set inputIndex of ops.
121 * 3.set outputIndex of ops.
122 */
Build(const std::vector<uint32_t> & paramsIndex,const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)123 OH_NN_ReturnCode SplitBuilder::Build(const std::vector<uint32_t> ¶msIndex,
124 const std::vector<uint32_t> &inputsIndex,
125 const std::vector<uint32_t> &outputsIndex,
126 const std::vector<std::shared_ptr<NNTensor>> &allTensors)
127 {
128 if (m_isBuild) {
129 LOGE("[SplitBuilder] Split operation has been build, cannot build again.");
130 return OH_NN_OPERATION_FORBIDDEN;
131 }
132
133 OH_NN_ReturnCode returnCode = SetInputAndOutput(inputsIndex, outputsIndex, allTensors);
134 if (returnCode != OH_NN_SUCCESS) {
135 LOGE("[SplitBuilder] Set index of inputs or outputs failed.");
136 return returnCode;
137 }
138
139 for (int i : paramsIndex) {
140 std::shared_ptr<NNTensor> tensor = allTensors[i];
141 tensor->IdentifyOpParameter();
142 switch (tensor->GetType()) {
143 case OH_NN_SPLIT_AXIS:
144 returnCode = SetAxis(tensor);
145 break;
146 case OH_NN_SPLIT_OUTPUT_NUM:
147 returnCode = SetOutputNum(tensor);
148 break;
149 case OH_NN_SPLIT_SIZE_SPLITS:
150 returnCode = SetSizeSplits(tensor);
151 break;
152 default:
153 LOGE("[SplitBuilder] Parameter Type is invalid. type=%d", tensor->GetType());
154 return OH_NN_INVALID_PARAMETER;
155 }
156
157 if (returnCode != OH_NN_SUCCESS) {
158 LOGE("[SplitBuilder] Passed invalid param.");
159 return returnCode;
160 }
161 }
162
163 m_isBuild = true;
164 m_name = OP_NAME;
165 return OH_NN_SUCCESS;
166 }
167
GetPrimitive()168 LiteGraphTensorPtr SplitBuilder::GetPrimitive()
169 {
170 if (!m_isBuild) {
171 LOGE("[SplitBuilder] Cannot get primitive before call build.");
172 return { nullptr, DestroyLiteGraphPrimitive };
173 }
174
175 auto primitive = mindspore::lite::MindIR_Split_CreatePrimitive(m_output_num, m_size_splits, m_axis);
176 if (primitive == nullptr) {
177 LOGE("[SplitBuilder] MindIR_Split_CreatePrimitive failed.");
178 return { nullptr, DestroyLiteGraphPrimitive };
179 }
180
181 LiteGraphTensorPtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive);
182 return graphPrimitivePtr;
183 }
184
185 REGISTER_OPS(SplitBuilder, OH_NN_OPS_SPLIT);
186 } // namespace Ops
187 } // namespace NeuralNetworkRuntime
188 } // namespace OHOS
189