• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "strided_slice_builder.h"
17 
18 #include "mindir.h"
19 
20 #include "interfaces/kits/c/neural_network_runtime_type.h"
21 
22 namespace OHOS {
23 namespace NeuralNetworkRuntime {
24 namespace Ops {
namespace {
// Number of input tensors StridedSlice consumes; passed to CheckIOIndex.
constexpr int INPUT_NUM = 4;
// Number of output tensors StridedSlice produces; passed to CheckIOIndex.
constexpr int OUTPUT_NUM = 1;
// Operation name assigned to m_name once Build() succeeds.
const std::string OP_NAME = "StridedSlice";
} // namespace
28 
StridedSliceBuilder()29 StridedSliceBuilder::StridedSliceBuilder() {}
30 
~StridedSliceBuilder()31 StridedSliceBuilder::~StridedSliceBuilder() {}
32 
SetInputOutput(const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)33 OH_NN_ReturnCode StridedSliceBuilder::SetInputOutput(const std::vector<uint32_t>& inputsIndex,
34                                                      const std::vector<uint32_t>& outputsIndex,
35                                                      const std::vector<std::shared_ptr<NNTensor>>& allTensors)
36 {
37     OH_NN_ReturnCode returnCode = CheckIOIndex(inputsIndex, outputsIndex, allTensors, INPUT_NUM, OUTPUT_NUM);
38     if (returnCode != OH_NN_SUCCESS) {
39         LOGE("[StridedSliceBuilder] Passed invalid input or output index.");
40         return returnCode;
41     }
42 
43     m_inputsIndex = inputsIndex;
44     m_outputsIndex = outputsIndex;
45 
46     return OH_NN_SUCCESS;
47 }
48 
SetBeginMask(std::shared_ptr<NNTensor> tensor)49 OH_NN_ReturnCode StridedSliceBuilder::SetBeginMask(std::shared_ptr<NNTensor> tensor)
50 {
51     if (tensor->GetDataType() != OH_NN_INT64) {
52         LOGE("[StridedSliceBuilder] The 5th input beginMask should be type HNN_INT64.");
53         return OH_NN_INVALID_PARAMETER;
54     }
55 
56     void* buffer = tensor->GetBuffer();
57     if (buffer == nullptr) {
58         LOGE("[StridedSliceBuilder] Tensor buffer is nullptr.");
59         return OH_NN_INVALID_PARAMETER;
60     }
61     m_begin_mask = *(static_cast<int64_t*>(buffer));
62 
63     return OH_NN_SUCCESS;
64 }
65 
SetEndMask(std::shared_ptr<NNTensor> tensor)66 OH_NN_ReturnCode StridedSliceBuilder::SetEndMask(std::shared_ptr<NNTensor> tensor)
67 {
68     if (tensor->GetDataType() != OH_NN_INT64) {
69         LOGE("[StridedSliceBuilder] The 6th input endMask should be type HNN_INT64.");
70         return OH_NN_INVALID_PARAMETER;
71     }
72 
73     void* buffer = tensor->GetBuffer();
74     if (buffer == nullptr) {
75         LOGE("[StridedSliceBuilder] Tensor buffer is nullptr.");
76         return OH_NN_INVALID_PARAMETER;
77     }
78     m_end_mask = *(static_cast<int64_t*>(buffer));
79 
80     return OH_NN_SUCCESS;
81 }
82 
SetEllipsisMask(std::shared_ptr<NNTensor> tensor)83 OH_NN_ReturnCode StridedSliceBuilder::SetEllipsisMask(std::shared_ptr<NNTensor> tensor)
84 {
85     if (tensor->GetDataType() != OH_NN_INT64) {
86         LOGE("[StridedSliceBuilder] The 7th input ellipsisMask should be type HNN_INT64.");
87         return OH_NN_INVALID_PARAMETER;
88     }
89 
90     void* buffer = tensor->GetBuffer();
91     if (buffer == nullptr) {
92         LOGE("[StridedSliceBuilder] Tensor buffer is nullptr.");
93         return OH_NN_INVALID_PARAMETER;
94     }
95     m_ellipsis_mask = *(static_cast<int64_t*>(buffer));
96 
97     return OH_NN_SUCCESS;
98 }
99 
SetNewAxisMask(std::shared_ptr<NNTensor> tensor)100 OH_NN_ReturnCode StridedSliceBuilder::SetNewAxisMask(std::shared_ptr<NNTensor> tensor)
101 {
102     if (tensor->GetDataType() != OH_NN_INT64) {
103         LOGE("[StridedSliceBuilder] The 8th input newAxisMask should be type HNN_INT64.");
104         return OH_NN_INVALID_PARAMETER;
105     }
106 
107     void* buffer = tensor->GetBuffer();
108     if (buffer == nullptr) {
109         LOGE("[StridedSliceBuilder] Tensor buffer is nullptr.");
110         return OH_NN_INVALID_PARAMETER;
111     }
112     m_new_axis_mask = *(static_cast<int64_t*>(buffer));
113 
114     return OH_NN_SUCCESS;
115 }
116 
SetShrinkAxisMask(std::shared_ptr<NNTensor> tensor)117 OH_NN_ReturnCode StridedSliceBuilder::SetShrinkAxisMask(std::shared_ptr<NNTensor> tensor)
118 {
119     if (tensor->GetDataType() != OH_NN_INT64) {
120         LOGE("[StridedSliceBuilder] The 9th input shrinkAxisMAsk should be type HNN_INT64.");
121         return OH_NN_INVALID_PARAMETER;
122     }
123 
124     void* buffer = tensor->GetBuffer();
125     if (buffer == nullptr) {
126         LOGE("[StridedSliceBuilder] Tensor buffer is nullptr.");
127         return OH_NN_INVALID_PARAMETER;
128     }
129     m_shrink_axis_mask = *(static_cast<int64_t*>(buffer));
130 
131     return OH_NN_SUCCESS;
132 }
133 
134 /**
135  * Build method.
136  * 1.set attr of ops.
137  * 2.set inputIndex of ops.
138  * 3.set outputIndex of ops.
139  */
Build(const std::vector<uint32_t> & paramsIndex,const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)140 OH_NN_ReturnCode StridedSliceBuilder::Build(const std::vector<uint32_t>& paramsIndex,
141                                             const std::vector<uint32_t>& inputsIndex,
142                                             const std::vector<uint32_t>& outputsIndex,
143                                             const std::vector<std::shared_ptr<NNTensor>>& allTensors)
144 {
145     if (m_isBuild) {
146         LOGE("[StridedSliceBuilder] StridedSlice operation has been build, cannot build again.");
147         return OH_NN_OPERATION_FORBIDDEN;
148     }
149 
150     OH_NN_ReturnCode returnCode = SetInputOutput(inputsIndex, outputsIndex, allTensors);
151     if (returnCode != OH_NN_SUCCESS) {
152         LOGE("[StridedSliceBuilder] Set index of inputs or outputs failed.");
153         return returnCode;
154     }
155 
156     for (int i : paramsIndex) {
157         std::shared_ptr<NNTensor> tensor = allTensors[i];
158         tensor->IdentifyOpParameter();
159         switch (tensor->GetType()) {
160             case OH_NN_STRIDED_SLICE_BEGIN_MASK:
161                 returnCode = SetBeginMask(tensor);
162                 break;
163             case OH_NN_STRIDED_SLICE_END_MASK:
164                 returnCode = SetEndMask(tensor);
165                 break;
166             case OH_NN_STRIDED_SLICE_ELLIPSIS_MASK:
167                 returnCode = SetEllipsisMask(tensor);
168                 break;
169             case OH_NN_STRIDED_SLICE_NEW_AXIS_MASK:
170                 returnCode = SetNewAxisMask(tensor);
171                 break;
172             case OH_NN_STRIDED_SLICE_SHRINK_AXIS_MASK:
173                 returnCode = SetShrinkAxisMask(tensor);
174                 break;
175             default:
176                 LOGE("[StridedSliceBuilder] Parameter Type is invalid. type=%d", tensor->GetType());
177                 return OH_NN_INVALID_PARAMETER;
178         }
179 
180         if (returnCode != OH_NN_SUCCESS) {
181             LOGE("[StridedSliceBuilder] Passed invalid param.");
182             return returnCode;
183         }
184     }
185 
186     m_isBuild = true;
187     m_name = OP_NAME;
188     return OH_NN_SUCCESS;
189 }
190 
GetPrimitive()191 LiteGraphPrimitvePtr StridedSliceBuilder::GetPrimitive()
192 {
193     if (!m_isBuild) {
194         LOGE("[StridedSliceBuilder] Cannot get primitive before call build.");
195         return {nullptr, DestroyLiteGraphPrimitive};
196     }
197 
198     auto primitive = mindspore::lite::MindIR_StridedSlice_CreatePrimitive(m_begin_mask, m_end_mask, m_ellipsis_mask,
199         m_new_axis_mask, m_shrink_axis_mask);
200     if (primitive == nullptr) {
201         LOGE("[StridedSliceBuilder] MindIR_StridedSlice_CreatePrimitive failed.");
202         return {nullptr, DestroyLiteGraphPrimitive};
203     }
204 
205     LiteGraphPrimitvePtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive);
206     return graphPrimitivePtr;
207 }
208 
209 REGISTER_OPS(StridedSliceBuilder, OH_NN_OPS_STRIDED_SLICE);
210 } // namespace Ops
211 } // namespace NeuralNetworkRuntime
212 } // namespace OHOS
213