1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_DPICO_COMMON_ANF_UTIL_H_
18 #define MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_DPICO_COMMON_ANF_UTIL_H_
19
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>
#include "mindapi/ir/tensor.h"
#include "include/errorcode.h"
#include "mindapi/base/logging.h"
#include "mindapi/ir/anf.h"
#include "mindapi/ir/func_graph.h"
27
28 using mindspore::lite::RET_ERROR;
29 using mindspore::lite::RET_NO_CHANGE;
30 using mindspore::lite::RET_OK;
31 using mindspore::lite::STATUS;
32
33 namespace mindspore {
34 namespace dpico {
35 bool CheckPrimitiveType(const api::AnfNodePtr &node, const api::PrimitivePtr &primitive_type);
36 STATUS GetPrimitiveType(const api::AnfNodePtr &node, std::string *name);
37 STATUS GetShapeVectorFromParameter(const api::AnfNodePtr &weight, ShapeVector *shape_vector);
38 std::vector<int> CastToInt(const api::ValuePtr &value);
39 size_t GetTupleGetItemOutIndex(const api::CNodePtr &tuple_get_item);
40 STATUS GetOutputShapesFromCNode(const api::CNodePtr &cnode, std::vector<ShapeVector> *output_shapes);
41 STATUS GetInputShapeFromCNode(const api::CNodePtr &cnode, size_t input_idx, ShapeVector *shape);
42 STATUS FetchShapeFromAbstract(const api::AbstractBasePtr &abstract, ShapeVector *shape);
43 STATUS FetchTypeIdFromAbstract(const api::AbstractBasePtr &abstract, TypeId *type_id);
44 int GetAnfNodeOutputShape(const api::AnfNodePtr &input, ShapeVector *shape_vector);
45 std::string TypeIdToString(TypeId type_id);
46 bool CheckInputs(const api::CNodePtr &cnode);
47 std::string GetCustomOutputName(const api::AnfNodePtr &node);
48 api::TensorPtr CreateTensorInfo(const void *data, size_t data_size, const std::vector<int64_t> &shape,
49 TypeId data_type);
50 api::AbstractBasePtr CreateTensorAbstract(const std::vector<int64_t> &shape, TypeId data_type);
51 int InitParameterFromTensorInfo(const api::ParameterPtr ¶m_node, const api::TensorPtr &tensor_info);
52 api::AbstractBasePtr GetCNodeInputAbstract(const api::CNodePtr &cnode, size_t index);
53 api::AbstractBasePtr GetAbstractFromAnfNode(const api::AnfNodePtr &cnode);
54 api::ParameterPtr BuildIntValueParameterNode(const api::FuncGraphPtr &func_graph, const int32_t &data,
55 const std::string &node_name);
56 api::ParameterPtr BuildIntVecParameterNode(const api::FuncGraphPtr &func_graph, const std::vector<int32_t> &data,
57 const std::string &node_name);
58 api::ParameterPtr BuildIntVec2DParameterNode(const api::FuncGraphPtr &func_graph,
59 const std::vector<std::vector<int32_t>> &data,
60 const std::string &node_name);
61 api::ParameterPtr BuildFloatValueParameterNode(const api::FuncGraphPtr &func_graph, const float &data,
62 const std::string &node_name);
63 api::CNodePtr GenTransposeNode(const api::FuncGraphPtr &func_graph, const api::AnfNodePtr &input_node,
64 const std::vector<int> &perm, const std::string &cnode_name);
65 api::TensorPtr GetTensorInfo(const api::AnfNodePtr &node);
66 std::vector<std::vector<int>> CastToVec2DInt(const api::ValuePtr &value);
67 bool GetBoolAttr(const api::AnfNodePtr &node, const std::string &attr_name);
68 STATUS GetDataTypeAndShape(const api::ParameterPtr ¶m_node, TypeId *data_type, ShapeVector *shape_vector);
69 STATUS GetShapeVectorFromStringTensor(const api::TensorPtr &tensor_info, ShapeVector *shape_vector, size_t *offset);
IntToSize(int u)70 inline size_t IntToSize(int u) {
71 if (u < 0) {
72 MS_LOG(WARNING) << "The int value(" << u << ") is less than 0.";
73 return SIZE_MAX;
74 }
75 return static_cast<size_t>(u);
76 }
77 } // namespace dpico
78 } // namespace mindspore
79
80 #endif // MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_DPICO_COMMON_ANF_UTIL_H_
81