• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "split_builder.h"
17 
18 namespace OHOS {
19 namespace NeuralNetworkRuntime {
20 namespace Ops {
// Number of entries SetInputAndOutput() requires in inputsIndex.
static constexpr int INPUT_NUM = 1;
// Operator name stored into m_name when Build() succeeds.
static const std::string OP_NAME{"Split"};
23 
SplitBuilder()24 SplitBuilder::SplitBuilder() {}
25 
~SplitBuilder()26 SplitBuilder::~SplitBuilder() {}
27 
SetInputAndOutput(const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)28 OH_NN_ReturnCode SplitBuilder::SetInputAndOutput(const std::vector<uint32_t> &inputsIndex,
29     const std::vector<uint32_t> &outputsIndex, const std::vector<std::shared_ptr<NNTensor>> &allTensors)
30 {
31     auto inputSize = inputsIndex.size();
32     if (inputSize != INPUT_NUM) {
33         LOGE("[SplitBuilder] The number of inputsIndex should be %d, its number is %zu.", INPUT_NUM, inputSize);
34         return OH_NN_INVALID_PARAMETER;
35     }
36 
37     auto allTensorSize = allTensors.size();
38     for (auto index : inputsIndex) {
39         if (index >= allTensorSize) {
40             LOGE("[SplitBuilder] InputsIndex of Split is out of range.");
41             return OH_NN_INVALID_PARAMETER;
42         }
43     }
44 
45     for (auto index : outputsIndex) {
46         if (index >= allTensorSize) {
47             LOGE("[SplitBuilder] OutputsIndex of Split is out of range.");
48             return OH_NN_INVALID_PARAMETER;
49         }
50     }
51 
52     m_inputsIndex = inputsIndex;
53     m_outputsIndex = outputsIndex;
54 
55     // The quantization type of the first output determinies that of the operator.
56     SetQuantType(outputsIndex, allTensors);
57 
58     return OH_NN_SUCCESS;
59 }
60 
SetAxis(std::shared_ptr<NNTensor> tensor)61 OH_NN_ReturnCode SplitBuilder::SetAxis(std::shared_ptr<NNTensor> tensor)
62 {
63     if (tensor->GetDataType() != OH_NN_INT64) {
64         LOGE("[SplitBuilder] The 4th input axis should be type OH_NN_INT64.");
65         return OH_NN_INVALID_PARAMETER;
66     }
67 
68     if (tensor->GetElementCount() != 1) {
69         LOGE("[SplitBuilder] The 4th input axis should be scaler.");
70         return OH_NN_INVALID_PARAMETER;
71     }
72 
73     void* buffer = tensor->GetBuffer();
74     if (buffer == nullptr) {
75         LOGE("[SplitBuilder] Tensor buffer is nullptr.");
76         return OH_NN_INVALID_PARAMETER;
77     }
78     m_axis = *(static_cast<const int64_t *>(buffer));
79 
80     return OH_NN_SUCCESS;
81 }
82 
SetOutputNum(std::shared_ptr<NNTensor> tensor)83 OH_NN_ReturnCode SplitBuilder::SetOutputNum(std::shared_ptr<NNTensor> tensor)
84 {
85     if (tensor->GetDataType() != OH_NN_INT64) {
86         LOGE("[SplitBuilder] The 2nd input outputNum should be type OH_NN_INT64.");
87         return OH_NN_INVALID_PARAMETER;
88     }
89 
90     if (tensor->GetElementCount() != 1) {
91         LOGE("[SoftmaxBuilder] The 2nd input outputNum should be scaler.");
92         return OH_NN_INVALID_PARAMETER;
93     }
94 
95     m_output_num = *(static_cast<const int64_t *>(tensor->GetBuffer()));
96 
97     return OH_NN_SUCCESS;
98 }
99 
SetSizeSplits(std::shared_ptr<NNTensor> tensor)100 OH_NN_ReturnCode SplitBuilder::SetSizeSplits(std::shared_ptr<NNTensor> tensor)
101 {
102     if (tensor->GetDataType() != OH_NN_INT64) {
103         LOGE("[SplitBuilder] The 3rd input sizeSplit should be type OH_NN_INT64.");
104         return OH_NN_INVALID_PARAMETER;
105     }
106 
107     const int64_t *size_splits_data_ptr = reinterpret_cast<const int64_t *>(tensor->GetBuffer());
108     for (uint32_t i = 0; i < tensor->GetElementCount(); i++) {
109         m_size_splits.push_back(*size_splits_data_ptr++);
110     }
111 
112     return OH_NN_SUCCESS;
113 }
114 
115 /**
116  * Build method.
117  * 1.set attr of ops.
118  * 2.set inputIndex of ops.
119  * 3.set outputIndex of ops.
120  */
Build(const std::vector<uint32_t> & paramsIndex,const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)121 OH_NN_ReturnCode SplitBuilder::Build(const std::vector<uint32_t> &paramsIndex,
122                                      const std::vector<uint32_t> &inputsIndex,
123                                      const std::vector<uint32_t> &outputsIndex,
124                                      const std::vector<std::shared_ptr<NNTensor>> &allTensors)
125 {
126     if (m_isBuild) {
127         LOGE("[SplitBuilder] Split operation has been build, cannot build again.");
128         return OH_NN_OPERATION_FORBIDDEN;
129     }
130 
131     OH_NN_ReturnCode returnCode = SetInputAndOutput(inputsIndex, outputsIndex, allTensors);
132     if (returnCode != OH_NN_SUCCESS) {
133         LOGE("[SplitBuilder] Set index of inputs or outputs failed.");
134         return returnCode;
135     }
136 
137     for (int i : paramsIndex) {
138         std::shared_ptr<NNTensor> tensor = allTensors[i];
139         tensor->IdentifyOpParameter();
140         switch (tensor->GetType()) {
141             case OH_NN_SPLIT_AXIS:
142                 returnCode = SetAxis(tensor);
143                 break;
144             case OH_NN_SPLIT_OUTPUT_NUM:
145                 returnCode = SetOutputNum(tensor);
146                 break;
147             case OH_NN_SPLIT_SIZE_SPLITS:
148                 returnCode = SetSizeSplits(tensor);
149                 break;
150             default:
151                 LOGE("[SplitBuilder] Parameter Type is invalid. type=%d", tensor->GetType());
152                 return OH_NN_INVALID_PARAMETER;
153         }
154 
155         if (returnCode != OH_NN_SUCCESS) {
156             LOGE("[SplitBuilder] Passed invalid param.");
157             return returnCode;
158         }
159     }
160 
161     m_isBuild = true;
162     m_name = OP_NAME;
163     return OH_NN_SUCCESS;
164 }
165 
GetPrimitive()166 LiteGraphTensorPtr SplitBuilder::GetPrimitive()
167 {
168     if (!m_isBuild) {
169         LOGE("[SplitBuilder] Cannot get primitive before call build.");
170         return { nullptr, DestroyLiteGraphPrimitive };
171     }
172 
173     auto primitive = mindspore::lite::MindIR_Split_CreatePrimitive(m_output_num, m_size_splits, m_axis);
174     if (primitive == nullptr) {
175         LOGE("[SplitBuilder] MindIR_Split_CreatePrimitive failed.");
176         return { nullptr, DestroyLiteGraphPrimitive };
177     }
178 
179     LiteGraphTensorPtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive);
180     return graphPrimitivePtr;
181 }
182 
183 REGISTER_OPS(SplitBuilder, OH_NN_OPS_SPLIT);
184 } // namespace Ops
185 } // namespace NeuralNetworkRuntime
186 } // namespace OHOS
187