1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "stack_builder.h"
17
18 #include "mindir.h"
19
20 namespace OHOS {
21 namespace NeuralNetworkRuntime {
22 namespace Ops {
// Build() rejects fewer than INPUT_MIN_NUM input indices and any output count other than OUTPUT_NUM.
static constexpr int INPUT_MIN_NUM = 2;
static constexpr int OUTPUT_NUM = 1;
// Name stored in m_name when Build() succeeds.
static const std::string OP_NAME = "Stack";
26
StackBuilder()27 StackBuilder::StackBuilder() {}
28
~StackBuilder()29 StackBuilder::~StackBuilder() {}
30
SetAxis(std::shared_ptr<NNTensor> tensor)31 OH_NN_ReturnCode StackBuilder::SetAxis(std::shared_ptr<NNTensor> tensor)
32 {
33 if (tensor->GetDataType() != OH_NN_INT64) {
34 LOGE("[StackBuilder] The last input axis should be type OH_NN_INT64.");
35 return OH_NN_INVALID_PARAMETER;
36 }
37
38 if (tensor->GetElementCount() != 1) {
39 LOGE("[StackBuilder] The last input axis should be scaler.");
40 return OH_NN_INVALID_PARAMETER;
41 }
42
43 void* buffer = tensor->GetBuffer();
44 if (buffer == nullptr) {
45 LOGE("[StackBuilder] Tensor buffer is nullptr.");
46 return OH_NN_INVALID_PARAMETER;
47 }
48 m_axis = *(static_cast<int64_t*>(buffer));
49
50 return OH_NN_SUCCESS;
51 }
52
53 /**
54 * Build method.
55 * 1.set attr of ops.
56 * 2.set inputIndex of ops.
57 * 3.set outputIndex of ops.
58 */
Build(const std::vector<uint32_t> & paramsIndex,const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)59 OH_NN_ReturnCode StackBuilder::Build(const std::vector<uint32_t>& paramsIndex,
60 const std::vector<uint32_t>& inputsIndex,
61 const std::vector<uint32_t>& outputsIndex,
62 const std::vector<std::shared_ptr<NNTensor>>& allTensors)
63 {
64 if (m_isBuild) {
65 LOGE("[StackBuilder] Stack operation has been build, cannot build again.");
66 return OH_NN_OPERATION_FORBIDDEN;
67 }
68
69 if (inputsIndex.size() < INPUT_MIN_NUM) {
70 LOGE("[StackBuilder] The number of index of inputs don't larger than %d.", INPUT_MIN_NUM);
71 return OH_NN_INVALID_PARAMETER;
72 }
73 if (outputsIndex.size() != OUTPUT_NUM) {
74 LOGE("[StackBuilder] The number of index of outputs don't equal to %d.", OUTPUT_NUM);
75 return OH_NN_INVALID_PARAMETER;
76 }
77
78 m_inputsIndex = inputsIndex;
79 m_outputsIndex = outputsIndex;
80
81 OH_NN_ReturnCode returnCode;
82 for (int i : paramsIndex) {
83 std::shared_ptr<NNTensor> tensor = allTensors[i];
84 tensor->IdentifyOpParameter();
85 switch (tensor->GetType()) {
86 case OH_NN_STACK_AXIS:
87 returnCode = SetAxis(tensor);
88 break;
89 default:
90 LOGE("[StackBuilder] Parameter Type is invalid. type=%d", tensor->GetType());
91 return OH_NN_INVALID_PARAMETER;
92 }
93
94 if (returnCode != OH_NN_SUCCESS) {
95 LOGE("[StackBuilder] Passed invalid param.");
96 return returnCode;
97 }
98 }
99
100 m_isBuild = true;
101 m_name = OP_NAME;
102 return OH_NN_SUCCESS;
103 }
104
GetPrimitive()105 LiteGraphTensorPtr StackBuilder::GetPrimitive()
106 {
107 if (!m_isBuild) {
108 LOGE("[StackBuilder] Cannot get primitive before call build.");
109 return {nullptr, DestroyLiteGraphPrimitive};
110 }
111
112 auto primitive = mindspore::lite::MindIR_Stack_CreatePrimitive(m_axis);
113 if (primitive == nullptr) {
114 LOGE("[StackBuilder] MindIR_Stack_CreatePrimitive failed.");
115 return {nullptr, DestroyLiteGraphPrimitive};
116 }
117
118 LiteGraphTensorPtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive);
119 return graphPrimitivePtr;
120 }
121
// Registers StackBuilder as the builder associated with the OH_NN_OPS_STACK operation type.
REGISTER_OPS(StackBuilder, OH_NN_OPS_STACK);
123 } // namespace Ops
124 } // namespace NeuralNetworkRuntime
125 } // namespace OHOS
126