1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "strided_slice_builder.h"
17
18 #include "mindir.h"
19
20 #include "interfaces/kits/c/neural_network_runtime_type.h"
21
22 namespace OHOS {
23 namespace NeuralNetworkRuntime {
24 namespace Ops {
/** Input tensor count handed to CheckIOIndex when validating the StridedSlice indices. */
static const int INPUT_NUM = 4;
/** Output tensor count handed to CheckIOIndex when validating the StridedSlice indices. */
static const int OUTPUT_NUM = 1;
/** Operation name stored in m_name once Build() succeeds. */
static const std::string OP_NAME = "StridedSlice";
28
StridedSliceBuilder()29 StridedSliceBuilder::StridedSliceBuilder() {}
30
~StridedSliceBuilder()31 StridedSliceBuilder::~StridedSliceBuilder() {}
32
SetInputOutput(const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)33 OH_NN_ReturnCode StridedSliceBuilder::SetInputOutput(const std::vector<uint32_t>& inputsIndex,
34 const std::vector<uint32_t>& outputsIndex,
35 const std::vector<std::shared_ptr<NNTensor>>& allTensors)
36 {
37 OH_NN_ReturnCode returnCode = CheckIOIndex(inputsIndex, outputsIndex, allTensors, INPUT_NUM, OUTPUT_NUM);
38 if (returnCode != OH_NN_SUCCESS) {
39 LOGE("[StridedSliceBuilder] Passed invalid input or output index.");
40 return returnCode;
41 }
42
43 m_inputsIndex = inputsIndex;
44 m_outputsIndex = outputsIndex;
45
46 return OH_NN_SUCCESS;
47 }
48
SetBeginMask(std::shared_ptr<NNTensor> tensor)49 OH_NN_ReturnCode StridedSliceBuilder::SetBeginMask(std::shared_ptr<NNTensor> tensor)
50 {
51 if (tensor->GetDataType() != OH_NN_INT64) {
52 LOGE("[StridedSliceBuilder] The 5th input beginMask should be type HNN_INT64.");
53 return OH_NN_INVALID_PARAMETER;
54 }
55
56 void* buffer = tensor->GetBuffer();
57 if (buffer == nullptr) {
58 LOGE("[StridedSliceBuilder] Tensor buffer is nullptr.");
59 return OH_NN_INVALID_PARAMETER;
60 }
61 m_begin_mask = *(static_cast<int64_t*>(buffer));
62
63 return OH_NN_SUCCESS;
64 }
65
SetEndMask(std::shared_ptr<NNTensor> tensor)66 OH_NN_ReturnCode StridedSliceBuilder::SetEndMask(std::shared_ptr<NNTensor> tensor)
67 {
68 if (tensor->GetDataType() != OH_NN_INT64) {
69 LOGE("[StridedSliceBuilder] The 6th input endMask should be type HNN_INT64.");
70 return OH_NN_INVALID_PARAMETER;
71 }
72
73 void* buffer = tensor->GetBuffer();
74 if (buffer == nullptr) {
75 LOGE("[StridedSliceBuilder] Tensor buffer is nullptr.");
76 return OH_NN_INVALID_PARAMETER;
77 }
78 m_end_mask = *(static_cast<int64_t*>(buffer));
79
80 return OH_NN_SUCCESS;
81 }
82
SetEllipsisMask(std::shared_ptr<NNTensor> tensor)83 OH_NN_ReturnCode StridedSliceBuilder::SetEllipsisMask(std::shared_ptr<NNTensor> tensor)
84 {
85 if (tensor->GetDataType() != OH_NN_INT64) {
86 LOGE("[StridedSliceBuilder] The 7th input ellipsisMask should be type HNN_INT64.");
87 return OH_NN_INVALID_PARAMETER;
88 }
89
90 void* buffer = tensor->GetBuffer();
91 if (buffer == nullptr) {
92 LOGE("[StridedSliceBuilder] Tensor buffer is nullptr.");
93 return OH_NN_INVALID_PARAMETER;
94 }
95 m_ellipsis_mask = *(static_cast<int64_t*>(buffer));
96
97 return OH_NN_SUCCESS;
98 }
99
SetNewAxisMask(std::shared_ptr<NNTensor> tensor)100 OH_NN_ReturnCode StridedSliceBuilder::SetNewAxisMask(std::shared_ptr<NNTensor> tensor)
101 {
102 if (tensor->GetDataType() != OH_NN_INT64) {
103 LOGE("[StridedSliceBuilder] The 8th input newAxisMask should be type HNN_INT64.");
104 return OH_NN_INVALID_PARAMETER;
105 }
106
107 void* buffer = tensor->GetBuffer();
108 if (buffer == nullptr) {
109 LOGE("[StridedSliceBuilder] Tensor buffer is nullptr.");
110 return OH_NN_INVALID_PARAMETER;
111 }
112 m_new_axis_mask = *(static_cast<int64_t*>(buffer));
113
114 return OH_NN_SUCCESS;
115 }
116
SetShrinkAxisMask(std::shared_ptr<NNTensor> tensor)117 OH_NN_ReturnCode StridedSliceBuilder::SetShrinkAxisMask(std::shared_ptr<NNTensor> tensor)
118 {
119 if (tensor->GetDataType() != OH_NN_INT64) {
120 LOGE("[StridedSliceBuilder] The 9th input shrinkAxisMAsk should be type HNN_INT64.");
121 return OH_NN_INVALID_PARAMETER;
122 }
123
124 void* buffer = tensor->GetBuffer();
125 if (buffer == nullptr) {
126 LOGE("[StridedSliceBuilder] Tensor buffer is nullptr.");
127 return OH_NN_INVALID_PARAMETER;
128 }
129 m_shrink_axis_mask = *(static_cast<int64_t*>(buffer));
130
131 return OH_NN_SUCCESS;
132 }
133
134 /**
135 * Build method.
136 * 1.set attr of ops.
137 * 2.set inputIndex of ops.
138 * 3.set outputIndex of ops.
139 */
Build(const std::vector<uint32_t> & paramsIndex,const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)140 OH_NN_ReturnCode StridedSliceBuilder::Build(const std::vector<uint32_t>& paramsIndex,
141 const std::vector<uint32_t>& inputsIndex,
142 const std::vector<uint32_t>& outputsIndex,
143 const std::vector<std::shared_ptr<NNTensor>>& allTensors)
144 {
145 if (m_isBuild) {
146 LOGE("[StridedSliceBuilder] StridedSlice operation has been build, cannot build again.");
147 return OH_NN_OPERATION_FORBIDDEN;
148 }
149
150 OH_NN_ReturnCode returnCode = SetInputOutput(inputsIndex, outputsIndex, allTensors);
151 if (returnCode != OH_NN_SUCCESS) {
152 LOGE("[StridedSliceBuilder] Set index of inputs or outputs failed.");
153 return returnCode;
154 }
155
156 for (int i : paramsIndex) {
157 std::shared_ptr<NNTensor> tensor = allTensors[i];
158 tensor->IdentifyOpParameter();
159 switch (tensor->GetType()) {
160 case OH_NN_STRIDED_SLICE_BEGIN_MASK:
161 returnCode = SetBeginMask(tensor);
162 break;
163 case OH_NN_STRIDED_SLICE_END_MASK:
164 returnCode = SetEndMask(tensor);
165 break;
166 case OH_NN_STRIDED_SLICE_ELLIPSIS_MASK:
167 returnCode = SetEllipsisMask(tensor);
168 break;
169 case OH_NN_STRIDED_SLICE_NEW_AXIS_MASK:
170 returnCode = SetNewAxisMask(tensor);
171 break;
172 case OH_NN_STRIDED_SLICE_SHRINK_AXIS_MASK:
173 returnCode = SetShrinkAxisMask(tensor);
174 break;
175 default:
176 LOGE("[StridedSliceBuilder] Parameter Type is invalid. type=%d", tensor->GetType());
177 return OH_NN_INVALID_PARAMETER;
178 }
179
180 if (returnCode != OH_NN_SUCCESS) {
181 LOGE("[StridedSliceBuilder] Passed invalid param.");
182 return returnCode;
183 }
184 }
185
186 m_isBuild = true;
187 m_name = OP_NAME;
188 return OH_NN_SUCCESS;
189 }
190
GetPrimitive()191 LiteGraphPrimitvePtr StridedSliceBuilder::GetPrimitive()
192 {
193 if (!m_isBuild) {
194 LOGE("[StridedSliceBuilder] Cannot get primitive before call build.");
195 return {nullptr, DestroyLiteGraphPrimitive};
196 }
197
198 auto primitive = mindspore::lite::MindIR_StridedSlice_CreatePrimitive(m_begin_mask, m_end_mask, m_ellipsis_mask,
199 m_new_axis_mask, m_shrink_axis_mask);
200 if (primitive == nullptr) {
201 LOGE("[StridedSliceBuilder] MindIR_StridedSlice_CreatePrimitive failed.");
202 return {nullptr, DestroyLiteGraphPrimitive};
203 }
204
205 LiteGraphPrimitvePtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive);
206 return graphPrimitivePtr;
207 }
208
// Associates StridedSliceBuilder with the OH_NN_OPS_STRIDED_SLICE operation type.
// NOTE(review): REGISTER_OPS is defined elsewhere; presumably it registers a factory for
// this builder at static-initialization time — confirm against the macro definition.
REGISTER_OPS(StridedSliceBuilder, OH_NN_OPS_STRIDED_SLICE);
210 } // namespace Ops
211 } // namespace NeuralNetworkRuntime
212 } // namespace OHOS
213