• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2022 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_LITE_SRC_EXTENDRT_UTILS_FUNC_GRAPH_UTILS_H_
20 #define MINDSPORE_LITE_SRC_EXTENDRT_UTILS_FUNC_GRAPH_UTILS_H_
21 
22 #include <utility>
23 #include <string>
24 #include <vector>
25 #include <tuple>
26 
27 #include "ir/anf.h"
28 #include "ir/dtype/type.h"
29 #include "ir/func_graph.h"
30 #include "include/api/data_type.h"
31 #include "include/api/status.h"
32 #include "mindspore/ccsrc/kernel/kernel.h"
33 #include "include/common/utils/anfalgo.h"
34 
35 namespace mindspore {
// Pair of an ANF node pointer and a size_t. Judging by the alias name and by the
// `tensor` parameter names used in FuncGraphUtils, the size_t is presumably the index
// of one output of that node -- TODO confirm against the implementation.
36 using AnfWithOutIndex = std::pair<AnfNodePtr, size_t>;
// Makes kernel::BaseOperatorPtr usable unqualified inside namespace mindspore
// (FuncGraphUtils::GetCNodeOperator below relies on this).
37 using kernel::BaseOperatorPtr;
38 
// Collection of static helper functions that operate on FuncGraph / AnfNode objects.
// Every member is static and the class declares no data members, so it acts as a
// namespace-like holder. Only declarations are visible in this header; the notes
// below describe what the signatures show and hedge anything inferred from names.
39 class FuncGraphUtils {
40  public:
     // Returns a tensor::TensorPtr for `input_node` (taken by value). Presumably the
     // constant tensor held by the node; behaviour for non-constant nodes is not
     // visible here -- TODO confirm.
41   static tensor::TensorPtr GetConstNodeValue(AnfNodePtr input_node);
     // Returns a vector of common::KernelWithIndex entries for `anf_node`;
     // by the name, these are the node's inputs.
42   static std::vector<common::KernelWithIndex> GetNodeInputs(const AnfNodePtr &anf_node);
43 
     // Writes a result through the `base_operator` out-pointer for `cnode`.
     // NOTE(review): the bool return presumably signals success -- confirm.
44   static bool GetCNodeOperator(const CNodePtr &cnode, BaseOperatorPtr *base_operator);
45 
     // Fills the `input_tensors` and `output_tensors` out-vectors for `cnode`.
     // NOTE(review): the bool return presumably signals success -- confirm.
46   static bool GetCNodeInputsOutputs(const CNodePtr &cnode, std::vector<AnfWithOutIndex> *input_tensors,
47                                     std::vector<AnfWithOutIndex> *output_tensors);
     // Fills the `inputs` out-vector for `func_graph`; bool return presumably signals success.
48   static bool GetFuncGraphInputs(const FuncGraphPtr &func_graph, std::vector<AnfWithOutIndex> *inputs);
     // Fills the `outputs` out-vector for `func_graph`; bool return presumably signals success.
49   static bool GetFuncGraphOutputs(const FuncGraphPtr &func_graph, std::vector<AnfWithOutIndex> *outputs);
50 
     // Accessors that take a (node, index) pair and return, respectively, a DataType,
     // a ShapeVector, a std::string name and an AbstractBasePtr for it.
51   static DataType GetTensorDataType(const AnfWithOutIndex &tensor);
52   static ShapeVector GetTensorShape(const AnfWithOutIndex &tensor);
53   static std::string GetTensorName(const AnfWithOutIndex &tensor);
54   static AbstractBasePtr GetAbstract(const AnfWithOutIndex &tensor);
55 
     // Fills `inputs` (tensor pointers) and `inputs_name` (strings) for `graph`.
     // Returns void, so no failure indication is visible at this interface.
56   static void GetFuncGraphInputsInfo(const FuncGraphPtr &graph, std::vector<tensor::TensorPtr> *inputs,
57                                      std::vector<std::string> *inputs_name);
     // Output-side counterpart: fills `outputs` and `output_names` for `graph`.
58   static void GetFuncGraphOutputsInfo(const FuncGraphPtr &graph, std::vector<tensor::TensorPtr> *outputs,
59                                       std::vector<std::string> *output_names);
     // Returns a Status for an operation on `graph`. By the name it converts the graph
     // to NHWC format; whether `graph` is modified in place is not visible -- TODO confirm.
60   static Status UnifyGraphToNHWCFormat(const FuncGraphPtr &graph);
61 
     // Takes a list of nodes and returns a tuple of (FuncGraphPtr, AnfNodePtrList, AnfNodePtrList).
     // NOTE(review): presumably (new graph, its inputs, its outputs) -- verify the element order.
62   static std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGraph(const AnfNodePtrList &lst);
63 
     // Returns a node list computed from `nodes`, the node-users map `users` and the set `seen`.
     // The selection rule is defined in the implementation, which is not visible here.
64   static AnfNodePtrList GetOutput(const AnfNodePtrList &nodes, const NodeUsersMap &users,
65                                   const mindspore::HashSet<AnfNodePtr> &seen);
     // Returns an AnfNodePtr for `node` in the context of graph `fg`. `inputs_ptr` and
     // `eqv_ptr` are non-const pointers, so they may be updated by the call -- TODO confirm.
66   static AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNodePtrList *inputs_ptr,
67                                     mindspore::HashMap<AnfNodePtr, AnfNodePtr> *eqv_ptr);
68 
69  private:
     // Private helper returning a ValuePtr for `input_node` (taken by value).
70   static ValuePtr GetNodeValuePtr(AnfNodePtr input_node);
71 };
72 }  // namespace mindspore
73 
74 #endif  // MINDSPORE_LITE_SRC_EXTENDRT_UTILS_FUNC_GRAPH_UTILS_H_
75