1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "sub_builder.h"
17 #include "frameworks/native/transform.h"
18 #include "frameworks/native/validation.h"
19
20 namespace OHOS {
21 namespace NeuralNetworkRuntime {
22 namespace Ops {
// Number of input / output operands the Sub operator must have; passed to CheckIOIndex in Build().
static constexpr int INPUT_NUM = 2;
static constexpr int OUTPUT_NUM = 1;
// Name of the operator built by this file.
static const std::string OP_NAME{"Sub"};
26
SubBuilder()27 SubBuilder::SubBuilder() {}
28
~SubBuilder()29 SubBuilder::~SubBuilder() {}
30
SetActivationType(std::shared_ptr<NNTensor> tensor)31 OH_NN_ReturnCode SubBuilder::SetActivationType(std::shared_ptr<NNTensor> tensor)
32 {
33 if (tensor->GetDataType() != OH_NN_INT8) {
34 LOGE("[SubBuilder] The 3rd input activation should be type OH_NN_INT8.");
35 return OH_NN_INVALID_PARAMETER;
36 }
37
38 if (tensor->GetElementCount() != 1) {
39 LOGE("[SubBuilder] The 3rd input activation should be scaler.");
40 return OH_NN_INVALID_PARAMETER;
41 }
42
43 void* buffer = tensor->GetBuffer();
44 if (buffer == nullptr) {
45 LOGE("[SubBuilder] Tensor buffer is nullptr.");
46 return OH_NN_INVALID_PARAMETER;
47 }
48
49 int8_t* fuseData = static_cast<int8_t*>(buffer);
50 if (!OHOS::NeuralNetworkRuntime::Validation::ValidateFuseType(static_cast<OH_NN_FuseType>(*fuseData))) {
51 LOGE("[SubBuilder] Fuse activation type is invalid");
52 return OH_NN_INVALID_PARAMETER;
53 }
54
55 auto fuseType = (OH_NN_FuseType)(*fuseData);
56 m_activationType = NNToMS::TransfromFusionType(fuseType);
57 return OH_NN_SUCCESS;
58 }
59
60 /**
61 * Build method.
62 * 1.set attr of ops.
63 * 2.set inputIndex of ops.
64 * 3.set outputIndex of ops.
65 */
Build(const std::vector<uint32_t> & paramsIndex,const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)66 OH_NN_ReturnCode SubBuilder::Build(const std::vector<uint32_t>& paramsIndex,
67 const std::vector<uint32_t>& inputsIndex,
68 const std::vector<uint32_t>& outputsIndex,
69 const std::vector<std::shared_ptr<NNTensor>>& allTensors)
70 {
71 if (m_isBuild) {
72 LOGE("[SubBuilder] Sub operation has been build, cannot build again.");
73 return OH_NN_OPERATION_FORBIDDEN;
74 }
75
76 OH_NN_ReturnCode returnCode = CheckIOIndex(inputsIndex, outputsIndex, allTensors, INPUT_NUM, OUTPUT_NUM);
77 if (returnCode != OH_NN_SUCCESS) {
78 LOGE("[SubBuilder] Passed invalid input or output index.");
79 return returnCode;
80 }
81
82 m_inputsIndex = inputsIndex;
83 m_outputsIndex = outputsIndex;
84
85 for (int i : paramsIndex) {
86 std::shared_ptr<NNTensor> tensor = allTensors[i];
87 tensor->IdentifyOpParameter();
88 switch (tensor->GetType()) {
89 case OH_NN_SUB_ACTIVATIONTYPE:
90 returnCode = SetActivationType(tensor);
91 break;
92 default:
93 LOGE("[SubBuilder] Parameter Type is invalid. type=%d", tensor->GetType());
94 return OH_NN_INVALID_PARAMETER;
95 }
96
97 if (returnCode != OH_NN_SUCCESS) {
98 LOGE("[SubBuilder] Passed invalid param.");
99 return returnCode;
100 }
101 }
102
103 // The quantization type of the first output determinies that of the operator.
104 SetQuantType(outputsIndex, allTensors);
105
106 m_isBuild = true;
107 m_name = "Sub";
108 return OH_NN_SUCCESS;
109 }
110
GetPrimitive()111 LiteGraphPrimitvePtr SubBuilder::GetPrimitive()
112 {
113 if (!m_isBuild) {
114 LOGE("[SubBuilder] Cannot get primitive before call build.");
115 return {nullptr, DestroyLiteGraphPrimitive};
116 }
117
118 auto primitive = mindspore::lite::MindIR_SubFusion_CreatePrimitive(m_activationType);
119 if (primitive == nullptr) {
120 LOGE("[SubBuilder] MindIR_SubFusion_CreatePrimitive failed.");
121 return {nullptr, DestroyLiteGraphPrimitive};
122 }
123
124 LiteGraphPrimitvePtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive);
125 return graphPrimitivePtr;
126 }
127
// Associates SubBuilder with the OH_NN_OPS_SUB operation type. REGISTER_OPS is
// defined outside this file (presumably it adds the builder to the ops registry; confirm in its definition).
REGISTER_OPS(SubBuilder, OH_NN_OPS_SUB);
129 } // namespace Ops
130 } // namespace NeuralNetworkRuntime
131 } // namespace OHOS
132