1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "split_builder.h"
17
18 namespace OHOS {
19 namespace NeuralNetworkRuntime {
20 namespace Ops {
// Split consumes exactly one data input; every other operand is supplied as a parameter tensor.
static constexpr int INPUT_NUM {1};
// Name assigned to the operation once Build() succeeds.
static const std::string OP_NAME {"Split"};
23
SplitBuilder()24 SplitBuilder::SplitBuilder() {}
25
~SplitBuilder()26 SplitBuilder::~SplitBuilder() {}
27
SetInputAndOutput(const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)28 OH_NN_ReturnCode SplitBuilder::SetInputAndOutput(const std::vector<uint32_t> &inputsIndex,
29 const std::vector<uint32_t> &outputsIndex, const std::vector<std::shared_ptr<NNTensor>> &allTensors)
30 {
31 auto inputSize = inputsIndex.size();
32 if (inputSize != INPUT_NUM) {
33 LOGE("[SplitBuilder] The number of inputsIndex should be %d, its number is %zu.", INPUT_NUM, inputSize);
34 return OH_NN_INVALID_PARAMETER;
35 }
36
37 auto allTensorSize = allTensors.size();
38 for (auto index : inputsIndex) {
39 if (index >= allTensorSize) {
40 LOGE("[SplitBuilder] InputsIndex of Split is out of range.");
41 return OH_NN_INVALID_PARAMETER;
42 }
43 }
44
45 for (auto index : outputsIndex) {
46 if (index >= allTensorSize) {
47 LOGE("[SplitBuilder] OutputsIndex of Split is out of range.");
48 return OH_NN_INVALID_PARAMETER;
49 }
50 }
51
52 m_inputsIndex = inputsIndex;
53 m_outputsIndex = outputsIndex;
54
55 // The quantization type of the first output determinies that of the operator.
56 SetQuantType(outputsIndex, allTensors);
57
58 return OH_NN_SUCCESS;
59 }
60
SetAxis(std::shared_ptr<NNTensor> tensor)61 OH_NN_ReturnCode SplitBuilder::SetAxis(std::shared_ptr<NNTensor> tensor)
62 {
63 if (tensor->GetDataType() != OH_NN_INT64) {
64 LOGE("[SplitBuilder] The 4th input axis should be type OH_NN_INT64.");
65 return OH_NN_INVALID_PARAMETER;
66 }
67
68 if (tensor->GetElementCount() != 1) {
69 LOGE("[SplitBuilder] The 4th input axis should be scaler.");
70 return OH_NN_INVALID_PARAMETER;
71 }
72
73 void* buffer = tensor->GetBuffer();
74 if (buffer == nullptr) {
75 LOGE("[SplitBuilder] Tensor buffer is nullptr.");
76 return OH_NN_INVALID_PARAMETER;
77 }
78 m_axis = *(static_cast<const int64_t *>(buffer));
79
80 return OH_NN_SUCCESS;
81 }
82
SetOutputNum(std::shared_ptr<NNTensor> tensor)83 OH_NN_ReturnCode SplitBuilder::SetOutputNum(std::shared_ptr<NNTensor> tensor)
84 {
85 if (tensor->GetDataType() != OH_NN_INT64) {
86 LOGE("[SplitBuilder] The 2nd input outputNum should be type OH_NN_INT64.");
87 return OH_NN_INVALID_PARAMETER;
88 }
89
90 if (tensor->GetElementCount() != 1) {
91 LOGE("[SoftmaxBuilder] The 2nd input outputNum should be scaler.");
92 return OH_NN_INVALID_PARAMETER;
93 }
94
95 m_output_num = *(static_cast<const int64_t *>(tensor->GetBuffer()));
96
97 return OH_NN_SUCCESS;
98 }
99
SetSizeSplits(std::shared_ptr<NNTensor> tensor)100 OH_NN_ReturnCode SplitBuilder::SetSizeSplits(std::shared_ptr<NNTensor> tensor)
101 {
102 if (tensor->GetDataType() != OH_NN_INT64) {
103 LOGE("[SplitBuilder] The 3rd input sizeSplit should be type OH_NN_INT64.");
104 return OH_NN_INVALID_PARAMETER;
105 }
106
107 const int64_t *size_splits_data_ptr = reinterpret_cast<const int64_t *>(tensor->GetBuffer());
108 for (uint32_t i = 0; i < tensor->GetElementCount(); i++) {
109 m_size_splits.push_back(*size_splits_data_ptr++);
110 }
111
112 return OH_NN_SUCCESS;
113 }
114
115 /**
116 * Build method.
117 * 1.set attr of ops.
118 * 2.set inputIndex of ops.
119 * 3.set outputIndex of ops.
120 */
Build(const std::vector<uint32_t> & paramsIndex,const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)121 OH_NN_ReturnCode SplitBuilder::Build(const std::vector<uint32_t> ¶msIndex,
122 const std::vector<uint32_t> &inputsIndex,
123 const std::vector<uint32_t> &outputsIndex,
124 const std::vector<std::shared_ptr<NNTensor>> &allTensors)
125 {
126 if (m_isBuild) {
127 LOGE("[SplitBuilder] Split operation has been build, cannot build again.");
128 return OH_NN_OPERATION_FORBIDDEN;
129 }
130
131 OH_NN_ReturnCode returnCode = SetInputAndOutput(inputsIndex, outputsIndex, allTensors);
132 if (returnCode != OH_NN_SUCCESS) {
133 LOGE("[SplitBuilder] Set index of inputs or outputs failed.");
134 return returnCode;
135 }
136
137 for (int i : paramsIndex) {
138 std::shared_ptr<NNTensor> tensor = allTensors[i];
139 tensor->IdentifyOpParameter();
140 switch (tensor->GetType()) {
141 case OH_NN_SPLIT_AXIS:
142 returnCode = SetAxis(tensor);
143 break;
144 case OH_NN_SPLIT_OUTPUT_NUM:
145 returnCode = SetOutputNum(tensor);
146 break;
147 case OH_NN_SPLIT_SIZE_SPLITS:
148 returnCode = SetSizeSplits(tensor);
149 break;
150 default:
151 LOGE("[SplitBuilder] Parameter Type is invalid. type=%d", tensor->GetType());
152 return OH_NN_INVALID_PARAMETER;
153 }
154
155 if (returnCode != OH_NN_SUCCESS) {
156 LOGE("[SplitBuilder] Passed invalid param.");
157 return returnCode;
158 }
159 }
160
161 m_isBuild = true;
162 m_name = OP_NAME;
163 return OH_NN_SUCCESS;
164 }
165
GetPrimitive()166 LiteGraphTensorPtr SplitBuilder::GetPrimitive()
167 {
168 if (!m_isBuild) {
169 LOGE("[SplitBuilder] Cannot get primitive before call build.");
170 return { nullptr, DestroyLiteGraphPrimitive };
171 }
172
173 auto primitive = mindspore::lite::MindIR_Split_CreatePrimitive(m_output_num, m_size_splits, m_axis);
174 if (primitive == nullptr) {
175 LOGE("[SplitBuilder] MindIR_Split_CreatePrimitive failed.");
176 return { nullptr, DestroyLiteGraphPrimitive };
177 }
178
179 LiteGraphTensorPtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive);
180 return graphPrimitivePtr;
181 }
182
// Registration hook for this builder. REGISTER_OPS is defined outside this file; presumably it associates
// OH_NN_OPS_SPLIT with SplitBuilder in the ops builder registry — confirm against the macro definition.
REGISTER_OPS(SplitBuilder, OH_NN_OPS_SPLIT);
184 } // namespace Ops
185 } // namespace NeuralNetworkRuntime
186 } // namespace OHOS
187