• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "reduceall_builder.h"
17 
18 #include "frameworks/native/ops_registry.h"
19 
20 namespace OHOS {
21 namespace NeuralNetworkRuntime {
22 namespace Ops {
// Expected input/output counts handed to CheckIOIndex() in Build().
constexpr int INPUT_NUM = 2;
constexpr int OUTPUT_NUM = 1;
// Required element count of the keep_dims parameter tensor (a single scalar).
constexpr int SCALE_LENGTH = 1;
// Name stored into m_name once Build() succeeds.
static const std::string OP_NAME = "ReduceAll";
27 
ReduceAllBuilder()28 ReduceAllBuilder::ReduceAllBuilder() {}
29 
~ReduceAllBuilder()30 ReduceAllBuilder:: ~ReduceAllBuilder() {}
31 
SetKeepDims(std::shared_ptr<NNTensor> tensor)32 OH_NN_ReturnCode ReduceAllBuilder::SetKeepDims(std::shared_ptr<NNTensor> tensor)
33 {
34     tensor->IdentifyOpParameter();
35     if (tensor->GetElementCount() != SCALE_LENGTH) {
36         LOGE("[ReduceAll] SetKeepDims failed, the keep_dims dimensions should be scalar.");
37         return OH_NN_INVALID_PARAMETER;
38     }
39 
40     if (tensor->GetDataType() != OH_NN_BOOL) {
41         LOGE("[ReduceAll] SetKeepDims failed, the keep_dims should be type OH_NN_BOOL.");
42         return OH_NN_INVALID_PARAMETER;
43     }
44 
45     void* buffer = tensor->GetBuffer();
46     if (buffer == nullptr) {
47         LOGE("[ReduceAll] SetKeepDims failed, the keep_dims passed buffer is empty.");
48         return OH_NN_INVALID_PARAMETER;
49     }
50 
51     m_keepDims = *(static_cast<bool*>(buffer));
52     return OH_NN_SUCCESS;
53 }
54 
Build(const std::vector<uint32_t> & paramsIndex,const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)55 OH_NN_ReturnCode ReduceAllBuilder::Build(const std::vector<uint32_t>& paramsIndex,
56                                          const std::vector<uint32_t>& inputsIndex,
57                                          const std::vector<uint32_t>& outputsIndex,
58                                          const std::vector<std::shared_ptr<NNTensor>>& allTensors)
59 {
60     if (m_isBuild) {
61         LOGE("[ReduceAll] Build failed, the ReduceAll operation has been build, cannot build again.");
62         return OH_NN_OPERATION_FORBIDDEN;
63     }
64 
65     OH_NN_ReturnCode returnCode = CheckIOIndex(inputsIndex, outputsIndex, allTensors, INPUT_NUM, OUTPUT_NUM);
66     if (returnCode != OH_NN_SUCCESS) {
67         LOGE("[ReduceAll] Build failed, passed invalid input or output index of ReduceAll operation index.");
68         return returnCode;
69     }
70 
71     m_inputsIndex = inputsIndex;
72     m_outputsIndex = outputsIndex;
73 
74     for (uint32_t i : paramsIndex) {
75         std::shared_ptr<NNTensor> tensor = allTensors[i];
76         switch (tensor->GetType()) {
77             case OH_NN_REDUCE_ALL_KEEP_DIMS:
78                 returnCode = SetKeepDims(tensor);
79                 break;
80             default:
81                 LOGE("[ReduceAll] Build failed, parameter type is invalid. type=%d", tensor->GetType());
82                 return OH_NN_INVALID_PARAMETER;
83         }
84 
85         if (returnCode != OH_NN_SUCCESS) {
86             LOGE("[ReduceAll] Build failed, passed invalid param.");
87             return returnCode;
88         }
89     }
90 
91     m_name = OP_NAME;
92     m_isBuild = true;
93     return OH_NN_SUCCESS;
94 }
95 
GetPrimitive()96 LiteGraphPrimitvePtr ReduceAllBuilder::GetPrimitive()
97 {
98     if (!m_isBuild) {
99         LOGE("[ReduceAll] GetPrimitive failed, cannot get primitive before call build.");
100         return {nullptr, DestroyLiteGraphPrimitive};
101     }
102 
103     bool reduceToEnd{false};
104     float coeff{0.0f};
105 
106     void* primitive = mindspore::lite::MindIR_ReduceFusion_CreatePrimitive(m_keepDims, m_mode, reduceToEnd, coeff);
107     LiteGraphPrimitvePtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive);
108     return graphPrimitivePtr;
109 }
110 
111 REGISTER_OPS(ReduceAllBuilder, OH_NN_OPS_REDUCE_ALL);
112 } // namespace Ops
113 } // namespace NeuralNetworkRuntime
114 } // namespace OHOS