1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_CONVERT_H_ 18 #define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_CONVERT_H_ 19 20 #define DRAW_GE_GRAPH 21 22 #include <memory> 23 #include <map> 24 #include <set> 25 #include <vector> 26 #include <unordered_map> 27 #include <string> 28 #include <utility> 29 #include <stack> 30 #include <fstream> 31 #include <sstream> 32 33 #include "ir/anf.h" 34 #include "ir/func_graph.h" 35 #include "transform/graph_ir/util.h" 36 #include "ir/tensor.h" 37 #include "transform/graph_ir/df_graph_manager.h" 38 #include "utils/config_manager.h" 39 #include "transform/graph_ir/op_adapter.h" 40 #include "graph/operator_reg.h" 41 #ifdef OPEN_SOURCE 42 #include "ge/client/ge_api.h" 43 #else 44 #include "external/ge/ge_api.h" 45 #endif 46 #include "graph/tensor.h" 47 #include "ops/hcom_ops.h" 48 49 namespace mindspore { 50 namespace transform { 51 using TensorOrderMap = std::map<std::string, std::shared_ptr<tensor::Tensor>>; 52 using HcomBroadcast = ge::op::HcomBroadcast; 53 class DfGraphConvertor { 54 public: DfGraphConvertor(const AnfGraphPtr & anf_graph)55 explicit DfGraphConvertor(const AnfGraphPtr &anf_graph) : anf_graph_(anf_graph) { 56 MS_EXCEPTION_IF_NULL(anf_graph); 57 df_graph_ = std::make_shared<DfGraph>(anf_graph_->ToString()); 58 #if (!defined ENABLE_GE) || (defined ENABLE_INFER) 59 training_ = anf_graph->has_flag("training"); 60 #else 61 training_ = ENABLE_TRAIN; 62 #endif 63 distribute_ = anf_graph->has_flag("broadcast_flag"); 64 if (anf_graph->has_flag("broadcast_flag")) { 65 ConfigManager::GetInstance().set_parallel_strategy(ParallelStrategy::DISTRIBUTION); 66 } else { 67 ConfigManager::GetInstance().set_parallel_strategy(ParallelStrategy::ONE_DEVICE); 68 } 69 70 MS_LOG(INFO) << "Create DfGraphConvertor with training: " << training_ << ", distribute: " << distribute_; 71 } 72 ~DfGraphConvertor()73 ~DfGraphConvertor() {} 74 75 static void RegisterAdapter(const std::string &name, OpAdapterPtr adpt); 76 static void RegisterAdapter(const std::string &name, OpAdapterPtr train_adpt, OpAdapterPtr infer_adpt); 77 DrawComputeGraph(const std::string & name)78 void DrawComputeGraph(const std::string &name) { 79 #ifndef ENABLE_SECURITY 80 std::ofstream fout(name); 81 if (!fout.is_open()) { 82 MS_LOG(ERROR) << "Open file '" << name << "' failed!"; 83 return; 84 } 85 fout << compute_sout_.str(); 86 fout.close(); 87 #endif 88 } 89 DrawInitGraph(const std::string & name)90 void DrawInitGraph(const std::string &name) { 91 #ifndef ENABLE_SECURITY 92 std::ofstream fout(name); 93 if (!fout.is_open()) { 94 MS_LOG(ERROR) << "Open file '" << name << "' failed!"; 95 return; 96 } 97 fout << init_sout_.str(); 98 fout.close(); 99 #endif 100 } DrawSaveCheckpointGraph(const std::string & name)101 void DrawSaveCheckpointGraph(const std::string &name) { 102 std::ofstream fout(name); 103 if (!fout.is_open()) { 104 MS_LOG(ERROR) << "Open file '" << name << "' failed!"; 105 return; 106 } 107 fout << checkpoint_sout_.str(); 108 fout.close(); 109 } 110 111 DfGraphConvertor &ConvertAllNode(); 112 DfGraphConvertor &BuildGraph(); 113 DfGraphConvertor &InitParam(const TensorOrderMap &tensors); 114 DfGraphConvertor &GenerateCheckpointGraph(); 115 DfGraphConvertor &GenerateBroadcastGraph(const TensorOrderMap &tensors); 116 void InitParamWithData(const TensorOrderMap &tensors); 117 void SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node); 118 void SetupBroadcast(const std::shared_ptr<HcomBroadcast> &broadcast, const std::vector<GeTensorDesc> &broadcast_desc, 119 const DfGraphPtr &broadcast_graph, std::vector<ge::Operator> broadcast_input); 120 void MakeDatasetHandler(const std::string &name, const size_t &input_idx, const AnfNodePtr &it); 121 void SetupParamInitSubGraph(const TensorOrderMap &tensors, std::vector<ge::Operator> *init_input); 122 void DrawParamInitSubGraph(const std::string &name, const AnfNodePtr &it); 123 124 DfGraphPtr GetComputeGraph(); 125 DfGraphPtr GetInitGraph(); 126 DfGraphPtr GetSaveCheckpointGraph(); 127 DfGraphPtr GetBroadcastGraph(); 128 static OpAdapterPtr FindAdapter(const std::string &op_name, bool train = false); 129 static OpAdapterPtr FindAdapter(AnfNodePtr node, bool train = false); ErrCode()130 int ErrCode() const { return static_cast<int>(error_); } 131 is_training()132 bool is_training() const { return training_; } set_training(bool is_training)133 void set_training(bool is_training) { training_ = is_training; } 134 135 protected: 136 void InitLoopVar(std::vector<ge::Operator> *init_input); 137 138 private: 139 std::ostringstream compute_sout_; 140 std::ostringstream init_sout_; 141 std::ostringstream checkpoint_sout_; 142 std::ostringstream restore_checkpoint_sout_; 143 std::unordered_map<AnfNode *, std::string> op_draw_name_; 144 std::map<std::string, std::string> param_format_; 145 146 AnfNodePtr TraceTupleGetItem(const CNodePtr &node, uint64_t *index); 147 AnfNodePtr TraceMakeTuple(const CNodePtr &node, uint64_t index); 148 AnfNodePtr TraceDepend(const CNodePtr &node); 149 OutHandler TraceRealOp(AnfNodePtr node); 150 OutHandler GetHandler(const AnfNodePtr &node, const std::stack<uint64_t> &index_stack, AnfNode *const draw_index); 151 OperatorPtr Convert(AnfNodePtr node); 152 OperatorPtr ConvertCNode(CNodePtr node); 153 std::vector<OperatorPtr> ConvertDependNode(AnfNodePtr node); 154 AnfNodePtr GetRealOpNode(AnfNodePtr node); 155 OperatorPtr ConvertParameter(AnfNodePtr node); 156 Status TryConvertValueNodeToMultiConst(const ValueNodePtr node); 157 OperatorPtr ConvertValueNode(ValueNodePtr node); 158 void SaveParamFormat(CNodePtr node); 159 void GetCaseNodeInput(const CNodePtr node, const CNodePtr input_node); 160 void ConvertTupleGetItem(const CNodePtr node); 161 void ConvertMakeTuple(const CNodePtr node); 162 void ConvertTopK(const CNodePtr node); 163 void ConvertReshape(const CNodePtr node); 164 std::vector<int64_t> CastToInt(const ValuePtr &value); 165 bool CheckCNode(const std::string &name, const CNodePtr node); 166 void TraceOutput(AnfNodePtr node); 167 void TraceOutputFromParameter(const AnfNodePtr &anf_out); 168 void TraceOutputFromTupleGetItem(const AnfNodePtr &anf_out); 169 void SetNodeInput(AnfNodePtr node); 170 void SetOpControlInput(const AnfNodePtr &node); 171 void UpdateOpDesc(AnfNodePtr node); 172 void SetSubgraph(AnfNodePtr node); 173 void ProcessSubgraph(AnfNodePtr node, const std::vector<AnfNodePtr> &inputs); 174 void BuildSaveCheckpointGraph(); 175 void DrawCNode(const CNodePtr node, const OpAdapterPtr adpt); 176 void UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const; 177 void AddGraphConstInput(const OperatorPtr &op); 178 OperatorPtr ToOperatorPtr(const AnfNodePtr &node); 179 bool IsSourceEdgeNode(const AnfNodePtr &node); 180 bool IsControlEdgeNode(const AnfNodePtr &node); 181 void AddEdgeForLoad(const AnfNodePtr &node); 182 void AddEdgeToCache(const AnfNodePtr &src, const AnfNodePtr &dest); 183 void FindDestOps(const AnfNodePtr &node, const std::shared_ptr<std::vector<AnfNodePtr>> &node_list, bool top); 184 AnfNodePtr ParseLoadInput(const CNodePtr &cnode); 185 void AutoMonadSetControlInput(const AnfNodePtr &node); 186 void AutoMonadCollectInput(const AnfNodePtr &node); 187 void AutoMonadSetInput(const AnfNodePtr &node); 188 void SetTupleOpInput(const OpAdapterPtr &adpt, const CNodePtr &node, const AnfNodePtr &pred, const OperatorPtr &src, 189 int index); 190 void UpdateTupleOutCache(void); 191 AnfNodePtr GetRealInputNode(const CNodePtr &node, const AnfNodePtr &input); 192 193 std::shared_ptr<AnfGraph> anf_graph_{nullptr}; 194 std::shared_ptr<DfGraph> df_graph_{nullptr}; 195 std::shared_ptr<DfGraph> init_graph_{nullptr}; 196 std::shared_ptr<DfGraph> save_ckp_graph_{nullptr}; 197 std::shared_ptr<DfGraph> restore_ckp_graph_{nullptr}; 198 std::shared_ptr<DfGraph> broadcast_graph_{nullptr}; 199 std::unordered_map<AnfNode *, DfGraph> branches_map_; 200 std::unordered_map<AnfNode *, OperatorPtr> op_cache_; 201 std::unordered_map<AnfNode *, std::vector<ControlEdge>> control_edge_cache_; 202 std::unordered_map<AnfNodePtr, std::set<AnfNodePtr>> monad_control_edge_cache_; 203 /* record "tuple_getitem"<->"out_handler" mapping */ 204 std::unordered_map<AnfNode *, OutHandler> out_handle_cache_; 205 /* record "make_tuple"<->"out_handler vector" mapping */ 206 std::unordered_map<AnfNode *, std::shared_ptr<std::vector<OutHandler>>> tuple_out_handle_cache_; 207 std::unordered_map<AnfNode *, std::shared_ptr<std::vector<AnfNodePtr>>> case_input_handle_cache_; 208 std::unordered_map<std::string, AnfNodePtr> params_; 209 std::unordered_map<std::string, OperatorPtr> vars_; 210 std::vector<std::pair<ge::Operator, std::string>> graph_outputs_; 211 std::vector<OperatorPtr> graph_const_inputs_; 212 std::vector<OperatorPtr> init_ops_; 213 std::vector<OperatorPtr> broadcast_ops_; 214 std::vector<AnfNodePtr> inputs_; 215 OperatorPtr dataset_iter_getnext_; 216 Status error_ = SUCCESS; 217 bool training_ = false; 218 bool distribute_ = false; 219 bool use_inputs_ = false; 220 }; 221 } // namespace transform 222 } // namespace mindspore 223 224 #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_CONVERT_H_ 225