/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_EXTENDRT_UTILS_FUNC_GRAPH_UTILS_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_UTILS_FUNC_GRAPH_UTILS_H_

#include <utility>
#include <string>
#include <vector>
#include <tuple>

// NOTE(review): NodeUsersMap, mindspore::HashSet / mindspore::HashMap and tensor::TensorPtr are used below,
// but no header in this list is named after them; presumably they are reached transitively through these
// includes -- confirm before reordering or trimming the list.
#include "ir/anf.h"
#include "ir/dtype/type.h"
#include "ir/func_graph.h"
#include "include/api/data_type.h"
#include "include/api/status.h"
#include "mindspore/ccsrc/kernel/kernel.h"
#include "include/common/utils/anfalgo.h"

namespace mindspore {
// An ANF node paired with a size_t. Going by the alias name, the second member is the index of one of the
// node's outputs, so a value of this type identifies a single tensor produced by that node (confirm in the
// implementation file).
using AnfWithOutIndex = std::pair<AnfNodePtr, size_t>;
// Makes kernel::BaseOperatorPtr usable unqualified inside namespace mindspore (and in every file that
// includes this header).
using kernel::BaseOperatorPtr;

// Collection of static helper functions operating on FuncGraph / ANF nodes. The class declares no data
// members and no non-static members; every function is only declared here and defined elsewhere.
//
// Conventions visible in the signatures:
//  - Parameters passed as non-const pointers (e.g. `inputs`, `outputs`, `base_operator`) look like output
//    parameters filled by the callee; whether a null pointer is tolerated is not visible from this header.
//  - For functions returning bool, the meaning of the result is presumably "true on success" -- TODO confirm
//    against the definitions.
class FuncGraphUtils {
 public:
  // Returns a tensor for `input_node`; judging by the name, the node is expected to hold a constant value.
  // What is returned for a non-constant node is not visible here.
  // Note: the node pointer is taken by value, unlike the other functions in this class.
  static tensor::TensorPtr GetConstNodeValue(AnfNodePtr input_node);
  // Returns a list of (node, index) pairs for `anf_node`; by its name, these are the node's inputs.
  static std::vector<common::KernelWithIndex> GetNodeInputs(const AnfNodePtr &anf_node);

  // Looks up the operator associated with `cnode` and hands it back through `base_operator`
  // (presumed output parameter).
  static bool GetCNodeOperator(const CNodePtr &cnode, BaseOperatorPtr *base_operator);

  // Collects the input and output tensors of `cnode` as (node, index) pairs into `input_tensors` and
  // `output_tensors` (presumed output parameters).
  static bool GetCNodeInputsOutputs(const CNodePtr &cnode, std::vector<AnfWithOutIndex> *input_tensors,
                                    std::vector<AnfWithOutIndex> *output_tensors);
  // Collects the inputs of `func_graph` as (node, index) pairs into `inputs` (presumed output parameter).
  static bool GetFuncGraphInputs(const FuncGraphPtr &func_graph, std::vector<AnfWithOutIndex> *inputs);
  // Collects the outputs of `func_graph` as (node, index) pairs into `outputs` (presumed output parameter).
  static bool GetFuncGraphOutputs(const FuncGraphPtr &func_graph, std::vector<AnfWithOutIndex> *outputs);

  // Accessors for a single tensor identified by a (node, index) pair: its data type, shape, name and
  // abstract value respectively. How each is derived from the node is defined in the implementation file.
  static DataType GetTensorDataType(const AnfWithOutIndex &tensor);
  static ShapeVector GetTensorShape(const AnfWithOutIndex &tensor);
  static std::string GetTensorName(const AnfWithOutIndex &tensor);
  static AbstractBasePtr GetAbstract(const AnfWithOutIndex &tensor);

  // Fills `inputs` with tensors and `inputs_name` with names for the inputs of `graph` (presumed output
  // parameters). Returns nothing, so failures, if any, are not reported through the return value.
  static void GetFuncGraphInputsInfo(const FuncGraphPtr &graph, std::vector<tensor::TensorPtr> *inputs,
                                     std::vector<std::string> *inputs_name);
  // Counterpart of GetFuncGraphInputsInfo for the outputs of `graph`: fills `outputs` and `output_names`.
  static void GetFuncGraphOutputsInfo(const FuncGraphPtr &graph, std::vector<tensor::TensorPtr> *outputs,
                                      std::vector<std::string> *output_names);
  // By its name, converts `graph` to NHWC layout; the outcome is reported as a Status.
  // NOTE(review): `graph` is a const reference to a shared pointer, so the pointed-to graph can still be
  // modified by the callee -- confirm whether the conversion happens in place.
  static Status UnifyGraphToNHWCFormat(const FuncGraphPtr &graph);

  // Builds a graph from the node list `lst` and returns a tuple of (FuncGraphPtr, AnfNodePtrList,
  // AnfNodePtrList). Presumably the tuple is (new graph, its inputs, its outputs) -- TODO confirm the order
  // of the two lists against the definition.
  static std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGraph(const AnfNodePtrList &lst);

  // Returns a node list computed from `nodes`, the node-users map `users` and the node set `seen`.
  // Presumably selects those nodes whose results are consumed outside `seen` -- confirm in the definition.
  static AnfNodePtrList GetOutput(const AnfNodePtrList &nodes, const NodeUsersMap &users,
                                  const mindspore::HashSet<AnfNodePtr> &seen);
  // Returns a node related to `node` in the context of graph `fg`. `inputs_ptr` and `eqv_ptr` are non-const
  // pointers and so may be updated by the call; `eqv_ptr` maps nodes to nodes (presumably original node ->
  // equivalent node in `fg` -- TODO confirm).
  static AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNodePtrList *inputs_ptr,
                                    mindspore::HashMap<AnfNodePtr, AnfNodePtr> *eqv_ptr);

 private:
  // Internal helper returning the value held by `input_node` as a ValuePtr. The node pointer is taken by
  // value here as well.
  static ValuePtr GetNodeValuePtr(AnfNodePtr input_node);
};
}  // namespace mindspore

#endif  // MINDSPORE_LITE_SRC_EXTENDRT_UTILS_FUNC_GRAPH_UTILS_H_