• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "split_builder.h"
17 
18 namespace OHOS {
19 namespace NeuralNetworkRuntime {
20 namespace Ops {
// Number of entries SetInputAndOutput requires in inputsIndex.
static constexpr int INPUT_NUM{1};
// Name assigned to m_name once Build succeeds.
static const std::string OP_NAME{"Split"};
23 
SplitBuilder()24 SplitBuilder::SplitBuilder() {}
25 
~SplitBuilder()26 SplitBuilder::~SplitBuilder() {}
27 
SetInputAndOutput(const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)28 OH_NN_ReturnCode SplitBuilder::SetInputAndOutput(const std::vector<uint32_t> &inputsIndex,
29     const std::vector<uint32_t> &outputsIndex, const std::vector<std::shared_ptr<NNTensor>> &allTensors)
30 {
31     auto inputSize = inputsIndex.size();
32     if (inputSize != INPUT_NUM) {
33         LOGE("[SplitBuilder] The number of inputsIndex should be %d, its number is %zu.", INPUT_NUM, inputSize);
34         return OH_NN_INVALID_PARAMETER;
35     }
36 
37     auto allTensorSize = allTensors.size();
38     bool isOverTensorSize = std::any_of(inputsIndex.begin(), inputsIndex.end(), [allTensorSize](uint32_t index) {
39         return index >= allTensorSize;
40     });
41     if (isOverTensorSize) {
42         LOGE("[SplitBuilder] InputsIndex of Split is out of range.");
43         return OH_NN_INVALID_PARAMETER;
44     }
45 
46     isOverTensorSize = std::any_of(outputsIndex.begin(), outputsIndex.end(), [allTensorSize](uint32_t index) {
47         return index >= allTensorSize;
48     });
49     if (isOverTensorSize) {
50         LOGE("[SplitBuilder] InputsIndex of Split is out of range.");
51         return OH_NN_INVALID_PARAMETER;
52     }
53 
54     m_inputsIndex = inputsIndex;
55     m_outputsIndex = outputsIndex;
56 
57     // The quantization type of the first output determinies that of the operator.
58     SetQuantType(outputsIndex, allTensors);
59 
60     return OH_NN_SUCCESS;
61 }
62 
SetAxis(std::shared_ptr<NNTensor> tensor)63 OH_NN_ReturnCode SplitBuilder::SetAxis(std::shared_ptr<NNTensor> tensor)
64 {
65     if (tensor->GetDataType() != OH_NN_INT64) {
66         LOGE("[SplitBuilder] The 4th input axis should be type OH_NN_INT64.");
67         return OH_NN_INVALID_PARAMETER;
68     }
69 
70     if (tensor->GetElementCount() != 1) {
71         LOGE("[SplitBuilder] The 4th input axis should be scaler.");
72         return OH_NN_INVALID_PARAMETER;
73     }
74 
75     void* buffer = tensor->GetBuffer();
76     if (buffer == nullptr) {
77         LOGE("[SplitBuilder] Tensor buffer is nullptr.");
78         return OH_NN_INVALID_PARAMETER;
79     }
80     m_axis = *(static_cast<const int64_t *>(buffer));
81 
82     return OH_NN_SUCCESS;
83 }
84 
SetOutputNum(std::shared_ptr<NNTensor> tensor)85 OH_NN_ReturnCode SplitBuilder::SetOutputNum(std::shared_ptr<NNTensor> tensor)
86 {
87     if (tensor->GetDataType() != OH_NN_INT64) {
88         LOGE("[SplitBuilder] The 2nd input outputNum should be type OH_NN_INT64.");
89         return OH_NN_INVALID_PARAMETER;
90     }
91 
92     if (tensor->GetElementCount() != 1) {
93         LOGE("[SoftmaxBuilder] The 2nd input outputNum should be scaler.");
94         return OH_NN_INVALID_PARAMETER;
95     }
96 
97     m_output_num = *(static_cast<const int64_t *>(tensor->GetBuffer()));
98 
99     return OH_NN_SUCCESS;
100 }
101 
SetSizeSplits(std::shared_ptr<NNTensor> tensor)102 OH_NN_ReturnCode SplitBuilder::SetSizeSplits(std::shared_ptr<NNTensor> tensor)
103 {
104     if (tensor->GetDataType() != OH_NN_INT64) {
105         LOGE("[SplitBuilder] The 3rd input sizeSplit should be type OH_NN_INT64.");
106         return OH_NN_INVALID_PARAMETER;
107     }
108 
109     const int64_t *size_splits_data_ptr = reinterpret_cast<const int64_t *>(tensor->GetBuffer());
110     for (uint32_t i = 0; i < tensor->GetElementCount(); i++) {
111         m_size_splits.push_back(*size_splits_data_ptr++);
112     }
113 
114     return OH_NN_SUCCESS;
115 }
116 
117 /**
118  * Build method.
119  * 1.set attr of ops.
120  * 2.set inputIndex of ops.
121  * 3.set outputIndex of ops.
122  */
Build(const std::vector<uint32_t> & paramsIndex,const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)123 OH_NN_ReturnCode SplitBuilder::Build(const std::vector<uint32_t> &paramsIndex,
124                                      const std::vector<uint32_t> &inputsIndex,
125                                      const std::vector<uint32_t> &outputsIndex,
126                                      const std::vector<std::shared_ptr<NNTensor>> &allTensors)
127 {
128     if (m_isBuild) {
129         LOGE("[SplitBuilder] Split operation has been build, cannot build again.");
130         return OH_NN_OPERATION_FORBIDDEN;
131     }
132 
133     OH_NN_ReturnCode returnCode = SetInputAndOutput(inputsIndex, outputsIndex, allTensors);
134     if (returnCode != OH_NN_SUCCESS) {
135         LOGE("[SplitBuilder] Set index of inputs or outputs failed.");
136         return returnCode;
137     }
138 
139     for (int i : paramsIndex) {
140         std::shared_ptr<NNTensor> tensor = allTensors[i];
141         tensor->IdentifyOpParameter();
142         switch (tensor->GetType()) {
143             case OH_NN_SPLIT_AXIS:
144                 returnCode = SetAxis(tensor);
145                 break;
146             case OH_NN_SPLIT_OUTPUT_NUM:
147                 returnCode = SetOutputNum(tensor);
148                 break;
149             case OH_NN_SPLIT_SIZE_SPLITS:
150                 returnCode = SetSizeSplits(tensor);
151                 break;
152             default:
153                 LOGE("[SplitBuilder] Parameter Type is invalid. type=%d", tensor->GetType());
154                 return OH_NN_INVALID_PARAMETER;
155         }
156 
157         if (returnCode != OH_NN_SUCCESS) {
158             LOGE("[SplitBuilder] Passed invalid param.");
159             return returnCode;
160         }
161     }
162 
163     m_isBuild = true;
164     m_name = OP_NAME;
165     return OH_NN_SUCCESS;
166 }
167 
GetPrimitive()168 LiteGraphTensorPtr SplitBuilder::GetPrimitive()
169 {
170     if (!m_isBuild) {
171         LOGE("[SplitBuilder] Cannot get primitive before call build.");
172         return { nullptr, DestroyLiteGraphPrimitive };
173     }
174 
175     auto primitive = mindspore::lite::MindIR_Split_CreatePrimitive(m_output_num, m_size_splits, m_axis);
176     if (primitive == nullptr) {
177         LOGE("[SplitBuilder] MindIR_Split_CreatePrimitive failed.");
178         return { nullptr, DestroyLiteGraphPrimitive };
179     }
180 
181     LiteGraphTensorPtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive);
182     return graphPrimitivePtr;
183 }
184 
// Passes SplitBuilder together with the OH_NN_OPS_SPLIT operation type to the REGISTER_OPS macro.
// NOTE(review): the macro is defined outside this file; presumably it registers the builder so it can
// be looked up by operation type - confirm against its definition.
185 REGISTER_OPS(SplitBuilder, OH_NN_OPS_SPLIT);
186 } // namespace Ops
187 } // namespace NeuralNetworkRuntime
188 } // namespace OHOS
189