• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_GENERATE_GRAPH_H_
18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_GENERATE_GRAPH_H_
19 
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <vector>
24 #include <utility>
25 
26 #include "ir/anf.h"
27 #include "ir/primitive.h"
28 #include "ops/op_def.h"
29 #include "utils/hash_map.h"
30 #include "frontend/optimizer/opt.h"
31 #include "frontend/parallel/strategy.h"
32 #include "frontend/parallel/tensor_layout/tensor_redistribution.h"
33 
34 namespace mindspore {
35 namespace parallel {
36 const char USING_HASH_NAME[] = "USING_HASH_NAME";
37 std::pair<bool, size_t> CheckAndGetValidIdxByOpDef(const ops::OpDefPtr &op_def, const std::string &op_name,
38                                                    const std::string &attr_name, size_t limit_size);
39 // Get the operator's path where the operator has be defined
40 const char *GetOpPythonPath(const char *op_name);
41 
42 // Init python operator Instance
43 ValuePtr CreateOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name);
44 std::vector<AnfNodePtr> ConvertToRealInputs(const OperatorName &op_name, const std::string &instance_name,
45                                             const AnfNodePtrList &inputs, const OperatorAttrs &attrs);
46 CNodePtr CreateCNodeByInputsAndAttr(const FuncGraphPtr &func_graph, const OperatorName &op_name,
47                                     const std::string &instance_name, const AnfNodePtrList &inputs,
48                                     const OperatorAttrs &attrs);
49 CNodePtr CreateNewCNodeForReplace(const CNodePtr &origin_node, const PrimitivePtr &new_prim);
50 
51 AnfNodePtr CreateTypeInt(int64_t nbits);
52 AnfNodePtr CreateTypeFloat(int64_t nbits);
53 AnfNodePtr CreatInt64Imm(int64_t value);
54 AnfNodePtr CreateFP32Imm(float value);
55 AnfNodePtr CreateBoolImm(bool value);
56 AnfNodePtr CreateInt32Tensor(int64_t value, bool int64_type = false);
57 AnfNodePtr CreateFP32Tensor(float value);
58 AnfNodePtr CreateStringImm(std::string value);
59 AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr &value_ptr);
60 AnfNodePtr CreateTuple(const std::vector<int64_t> &tuple);
61 std::string HashInstanceName(const std::string &name);
62 void InsertVirtualPipelineEndNode(const CNodePtr &cnode, const FuncGraphManagerPtr &manager, size_t index,
63                                   std::string end_flag = "pipeline_end");
64 CNodePtr CreateVirtualConverterBeginNode(const CNodePtr &input_cnode, size_t output_nums);
65 CNodePtr CreateVirtualConverterEndNode(const FuncGraphPtr &graph, const std::vector<CNodePtr> &input_cnodes);
66 
67 class GenerateGraph {
68  public:
GenerateGraph(const mindspore::HashMap<std::string,ValuePtr> & origin_attrs)69   explicit GenerateGraph(const mindspore::HashMap<std::string, ValuePtr> &origin_attrs)
70       : name_idx_(0), origin_attrs_(origin_attrs) {}
71   Status Init(const CNodePtr &cnode);
72   ~GenerateGraph() = default;
virtual_input_node()73   AnfNodePtr virtual_input_node() { return virtual_input_node_; }
74   AnfNodePtr NewOpInst(const OperatorName &op_name, const OperatorAttrs &attrs);
75   AnfNodePtr NewOpInst(const OperatorName &op_name);
76   AnfNodePtr PushBack(const std::vector<AnfNodePtr> &inputs);
77 
78  private:
79   CNodePtr cnode_;
80   FuncGraphManagerPtr manager_;
81   ScopePtr scope_;
82   FuncGraphPtr func_graph_;
83   AnfNodePtr virtual_input_node_;
84   std::string instance_name_base_;
85   int64_t name_idx_;
86   mindspore::HashMap<std::string, ValuePtr> origin_attrs_;
87 };
88 }  // namespace parallel
89 }  // namespace mindspore
90 
91 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_GENERATE_GRAPH_H_
92