• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "sub_builder.h"
17 #include "frameworks/native/transform.h"
18 #include "frameworks/native/validation.h"
19 
20 namespace OHOS {
21 namespace NeuralNetworkRuntime {
22 namespace Ops {
// Tensor counts that Build() hands to CheckIOIndex: Sub takes two inputs and yields one output.
constexpr int INPUT_NUM{2};
constexpr int OUTPUT_NUM{1};
// Operator name string; the same text Build() stores in m_name.
const std::string OP_NAME{"Sub"};
26 
SubBuilder()27 SubBuilder::SubBuilder() {}
28 
~SubBuilder()29 SubBuilder::~SubBuilder() {}
30 
SetActivationType(std::shared_ptr<NNTensor> tensor)31 OH_NN_ReturnCode SubBuilder::SetActivationType(std::shared_ptr<NNTensor> tensor)
32 {
33     if (tensor->GetDataType() != OH_NN_INT8) {
34         LOGE("[SubBuilder] The 3rd input activation should be type OH_NN_INT8.");
35         return OH_NN_INVALID_PARAMETER;
36     }
37 
38     if (tensor->GetElementCount() != 1) {
39         LOGE("[SubBuilder] The 3rd input activation should be scaler.");
40         return OH_NN_INVALID_PARAMETER;
41     }
42 
43     void* buffer = tensor->GetBuffer();
44     if (buffer == nullptr) {
45         LOGE("[SubBuilder] Tensor buffer is nullptr.");
46         return OH_NN_INVALID_PARAMETER;
47     }
48 
49     int8_t* fuseData = static_cast<int8_t*>(buffer);
50     if (!OHOS::NeuralNetworkRuntime::Validation::ValidateFuseType(static_cast<OH_NN_FuseType>(*fuseData))) {
51         LOGE("[SubBuilder] Fuse activation type is invalid");
52         return OH_NN_INVALID_PARAMETER;
53     }
54 
55     auto fuseType = (OH_NN_FuseType)(*fuseData);
56     m_activationType = NNToMS::TransfromFusionType(fuseType);
57     return OH_NN_SUCCESS;
58 }
59 
60 /**
61  * Build method.
62  * 1.set attr of ops.
63  * 2.set inputIndex of ops.
64  * 3.set outputIndex of ops.
65  */
Build(const std::vector<uint32_t> & paramsIndex,const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)66 OH_NN_ReturnCode SubBuilder::Build(const std::vector<uint32_t>& paramsIndex,
67                                    const std::vector<uint32_t>& inputsIndex,
68                                    const std::vector<uint32_t>& outputsIndex,
69                                    const std::vector<std::shared_ptr<NNTensor>>& allTensors)
70 {
71     if (m_isBuild) {
72         LOGE("[SubBuilder] Sub operation has been build, cannot build again.");
73         return OH_NN_OPERATION_FORBIDDEN;
74     }
75 
76     OH_NN_ReturnCode returnCode = CheckIOIndex(inputsIndex, outputsIndex, allTensors, INPUT_NUM, OUTPUT_NUM);
77     if (returnCode != OH_NN_SUCCESS) {
78         LOGE("[SubBuilder] Passed invalid input or output index.");
79         return returnCode;
80     }
81 
82     m_inputsIndex = inputsIndex;
83     m_outputsIndex = outputsIndex;
84 
85     for (int i : paramsIndex) {
86         std::shared_ptr<NNTensor> tensor = allTensors[i];
87         tensor->IdentifyOpParameter();
88         switch (tensor->GetType()) {
89             case OH_NN_SUB_ACTIVATIONTYPE:
90                 returnCode = SetActivationType(tensor);
91                 break;
92             default:
93                 LOGE("[SubBuilder] Parameter Type is invalid. type=%d", tensor->GetType());
94                 return OH_NN_INVALID_PARAMETER;
95         }
96 
97         if (returnCode != OH_NN_SUCCESS) {
98             LOGE("[SubBuilder] Passed invalid param.");
99             return returnCode;
100         }
101     }
102 
103     // The quantization type of the first output determinies that of the operator.
104     SetQuantType(outputsIndex, allTensors);
105 
106     m_isBuild = true;
107     m_name = "Sub";
108     return OH_NN_SUCCESS;
109 }
110 
GetPrimitive()111 LiteGraphPrimitvePtr SubBuilder::GetPrimitive()
112 {
113     if (!m_isBuild) {
114         LOGE("[SubBuilder] Cannot get primitive before call build.");
115         return {nullptr, DestroyLiteGraphPrimitive};
116     }
117 
118     auto primitive = mindspore::lite::MindIR_SubFusion_CreatePrimitive(m_activationType);
119     if (primitive == nullptr) {
120         LOGE("[SubBuilder] MindIR_SubFusion_CreatePrimitive failed.");
121         return {nullptr, DestroyLiteGraphPrimitive};
122     }
123 
124     LiteGraphPrimitvePtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive);
125     return graphPrimitivePtr;
126 }
127 
// Passes SubBuilder and the OH_NN_OPS_SUB operation type to the REGISTER_OPS macro.
// NOTE(review): the macro is defined outside this file; presumably it registers the builder
// so it can be looked up by operation type - confirm against the macro definition.
128 REGISTER_OPS(SubBuilder, OH_NN_OPS_SUB);
129 } // namespace Ops
130 } // namespace NeuralNetworkRuntime
131 } // namespace OHOS
132