1 /** 2 * Copyright 2019-2024 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_CONVERT_H_ 18 #define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_CONVERT_H_ 19 20 #define DRAW_GE_GRAPH 21 22 #include <cstdlib> 23 #include <memory> 24 #include <map> 25 #include <set> 26 #include <unordered_set> 27 #include <vector> 28 #include <string> 29 #include <utility> 30 #include <stack> 31 #include <fstream> 32 #include <sstream> 33 #include "include/common/utils/config_manager.h" 34 #include "mindspore/core/ops/structure_ops.h" 35 #include "utils/hash_map.h" 36 #include "utils/ms_context.h" 37 #include "ir/anf.h" 38 #include "ir/func_graph.h" 39 #include "ir/tensor.h" 40 #include "transform/graph_ir/df_graph_manager.h" 41 #include "transform/graph_ir/op_adapter.h" 42 #include "graph/operator_reg.h" 43 #include "ge/ge_api.h" 44 45 namespace mindspore { 46 namespace transform { 47 class BaseOpAdapter; 48 49 using ParamIndexMap = std::map<std::size_t, std::size_t>; 50 enum class GraphType { kNormal, kCond, kBody, kAfter, kBranch }; 51 enum class DfsVisitFlag { kUnVisited, kVisiting, kVisited }; 52 enum class RefModeFlag { 53 kRefModeNone, 54 kRefModeVariable, // Only Variables will be treated as RefData 55 kRefModeAll, // All Parameter including Variables and Constants will be treated as RefData 56 kRefModeEnv // depend on REF_MODE, default value is on, ref mode type will be kRefModeAll 57 }; 58 constexpr char kGraphFlagHasGetNext[] = "graph_has_getnext"; 59 constexpr char kGraphNeedIteration[] = "graph_need_iteration"; 60 61 struct GEInputList { 62 std::vector<AnfNodeWeakPtr> ge_inputs; 63 constexpr static char key[] = "GEInputs"; 64 }; 65 66 class GeOpConvertor { 67 public: 68 static std::map<std::string, ValuePtr> GetAttrAndValue(const AnfNodePtr &node, const bool training); 69 70 static std::string GetOpType(const AnfNodePtr &node, const bool training); 71 72 static std::shared_ptr<GeTensorDesc> GetTensorDesc(const ShapeVector &dev_shape, const TypeId &dev_type, 73 const std::string &dev_format, const ShapeVector &ori_shape, 74 const std::string &ori_format); 75 76 static mindspore::HashMap<std::string, std::string> GetNeedAddInput(const AnfNodePtr &node, const bool training); 77 78 static bool IsDynamicInput(const AnfNodePtr &node, const size_t idx); 79 80 static std::map<int, std::string> GetAclInputNames(const AnfNodePtr &node); 81 82 static std::map<int, std::string> GetAclOutputNames(const AnfNodePtr &node); 83 84 static std::map<int, std::string> GetAclDynamicInputNames(const AnfNodePtr &node); 85 86 static std::map<int, std::string> GetAclDynamicOutputNames(const AnfNodePtr &node); 87 }; 88 89 DfGraphPtr GenExampleGraph(const std::string &name); 90 91 using SetDynRefDataFunc = std::function<ShapeVector(const AnfNodePtr &, const ShapeVector &)>; 92 93 class DfGraphConvertor { 94 public: 95 explicit DfGraphConvertor(const AnfGraphPtr &anf_graph, const std::string &phase_prefix, 96 RefModeFlag ref_mode_type = RefModeFlag::kRefModeEnv, 97 const std::vector<std::string> &extra_variables_names = {}, 98 SetDynRefDataFunc dyn_ref_data_func = nullptr, bool offline_convert = false) anf_graph_(anf_graph)99 : anf_graph_(anf_graph), 100 extra_variables_names_(extra_variables_names), 101 phase_prefix_(phase_prefix), 102 offline_convert_(offline_convert) { 103 MS_EXCEPTION_IF_NULL(anf_graph); 104 if (ref_mode_type == RefModeFlag::kRefModeEnv) { 105 ref_mode_ = IsEnableRefMode(); 106 ref_mode_type_ = RefModeFlag::kRefModeAll; 107 } else { 108 ref_mode_ = (ref_mode_type != RefModeFlag::kRefModeNone); 109 ref_mode_type_ = ref_mode_type; 110 } 111 dyn_ref_data_func_ = dyn_ref_data_func; 112 auto context = MsContext::GetInstance(); 113 MS_EXCEPTION_IF_NULL(context); 114 bool enable_ge = context->backend_policy() == "ge"; 115 bool enable_training = phase_prefix_ == "train"; 116 static bool is_training = false; 117 if (enable_ge && enable_training) { 118 is_training = true; 119 } 120 if (is_training) { 121 training_ = true; 122 } else { 123 training_ = anf_graph->has_flag("training"); 124 } 125 distribute_ = anf_graph->has_flag("broadcast_flag"); 126 if (anf_graph->has_flag("broadcast_flag")) { 127 ConfigManager::GetInstance().set_parallel_strategy(ParallelStrategy::DISTRIBUTION); 128 } else { 129 ConfigManager::GetInstance().set_parallel_strategy(ParallelStrategy::ONE_DEVICE); 130 } 131 is_kernel_graph_ = anf_graph_->type_name() == kKernelGraphTypeName; 132 df_graph_ = std::make_shared<DfGraph>(anf_graph_->ToString()); 133 134 std::string graph_type = is_kernel_graph_ ? "kernel_graph" : "func_graph"; 135 std::string graph_name = anf_graph_->ToString(); 136 graph_manager_ = Manage(anf_graph_, true); 137 MS_EXCEPTION_IF_NULL(graph_manager_); 138 MS_LOG(INFO) << "Create DfGraphConvertor with graph: " << graph_name << "(type: " << graph_type << ")" 139 << ", training: " << training_ << ", dynamic input: " << dynamic_shape_inputs_ 140 << ", distribute: " << distribute_; 141 } 142 ~DfGraphConvertor()143 ~DfGraphConvertor() {} 144 145 static void RegisterAdapter(const std::string &name, OpAdapterPtr adpt); 146 static void RegisterAdapter(const std::string &name, OpAdapterPtr train_adpt, OpAdapterPtr infer_adpt); 147 DrawComputeGraph(const std::string & name)148 void DrawComputeGraph(const std::string &name) { 149 #ifndef ENABLE_SECURITY 150 std::ofstream fout(name); 151 if (!fout.is_open()) { 152 MS_LOG(ERROR) << "Open file '" << name << "' failed!"; 153 return; 154 } 155 fout << compute_sout_.str(); 156 fout.close(); 157 #endif 158 } 159 DrawInitGraph(const std::string & name)160 void DrawInitGraph(const std::string &name) { 161 #ifndef ENABLE_SECURITY 162 std::ofstream fout(name); 163 if (!fout.is_open()) { 164 MS_LOG(ERROR) << "Open file '" << name << "' failed!"; 165 return; 166 } 167 fout << init_sout_.str(); 168 fout.close(); 169 #endif 170 } 171 DrawSaveCheckpointGraph(const std::string & name)172 void DrawSaveCheckpointGraph(const std::string &name) { 173 std::ofstream fout(name); 174 if (!fout.is_open()) { 175 MS_LOG(ERROR) << "Open file '" << name << "' failed!"; 176 return; 177 } 178 fout << checkpoint_sout_.str(); 179 fout.close(); 180 } 181 182 DfGraphConvertor &ConvertAllNode(); 183 void GenFakeGraph(const std::string &name); 184 DfGraphConvertor &BuildGraph(const std::string &name); 185 DfGraphConvertor &InitParam(const TensorOrderMap &tensors); 186 DfGraphConvertor &GenerateCheckpointGraph(); 187 DfGraphConvertor &GenerateBroadcastGraph(const TensorOrderMap &tensors); 188 void InitParamWithData(const TensorOrderMap &tensors); 189 bool NodeInputKeepUpdate(const FuncGraphManagerPtr &manager, const AnfNodePtr &node); 190 OutHandler GetNormalOpInput(const AnfNodePtr &node, const AnfNodePtr &pred); 191 void DrawOpInput(const AnfNodePtr &node, const AnfNodePtr &pred, size_t i); 192 void SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node); 193 void SetOpAttrToInput(const OpAdapterPtr &adpt, const CNodePtr &node); 194 void SetupBroadcast(const OperatorPtr &broadcast, const std::vector<GeTensorDesc> &broadcast_desc, 195 const DfGraphPtr &broadcast_graph, std::vector<::ge::Operator> broadcast_input); 196 void SetupParamInitSubGraph(const TensorOrderMap &tensors, const std::vector<::ge::Operator> *init_input, 197 bool is_sink_size_repeat); 198 void SetupParamInitSubGraph(); 199 void DrawParamInitSubGraph(const std::string &name, const AnfNodePtr &it); 200 201 DfGraphPtr GetComputeGraph(); 202 DfGraphPtr GetInitGraph(); GetInitDataNames()203 std::vector<std::string> GetInitDataNames() const { return init_data_names_; } GetRefDataNames()204 std::vector<std::string> GetRefDataNames() const { return ref_data_names_; } 205 DfGraphPtr GetSaveCheckpointGraph(); 206 DfGraphPtr GetBroadcastGraph(); ErrCode()207 int ErrCode() const { return static_cast<int>(error_); } 208 is_training()209 bool is_training() const { return training_; } set_training(bool is_training)210 void set_training(bool is_training) { training_ = is_training; } 211 export_air()212 bool export_air() const { return export_air_; } set_export_air(bool export_air)213 void set_export_air(bool export_air) { export_air_ = export_air; } dynamic_shape_inputs()214 bool dynamic_shape_inputs() const { return dynamic_shape_inputs_; } input_shapes()215 std::vector<ShapeVector> input_shapes() { return input_shapes_; } 216 217 void SetupInputFormat(const FuncGraphManagerPtr &manager, const AnfNodePtr &node); 218 219 protected: 220 bool InitLoopVar(std::vector<::ge::Operator> *init_input); 221 222 private: 223 std::ostringstream compute_sout_; 224 std::ostringstream init_sout_; 225 std::ostringstream checkpoint_sout_; 226 std::ostringstream restore_checkpoint_sout_; 227 mindspore::HashMap<AnfNode *, std::string> op_draw_name_; 228 std::map<std::string, std::string> param_format_; 229 230 OutHandler GetHandler(const AnfNodePtr &node); 231 OperatorPtr Convert(AnfNodePtr node); 232 OperatorPtr ConvertCNode(CNodePtr node); 233 OperatorPtr ConvertParameter(AnfNodePtr node); 234 void SetNodeAbstract(const CNodePtr &node) const; 235 Status TryConvertValueNodeToMultiConst(const ValueNodePtr node); 236 OperatorPtr ConvertValueNode(ValueNodePtr node); 237 void SaveParamFormat(CNodePtr node); 238 void GetBranchNodeInput(const CNodePtr node); 239 void ConvertTopK(const CNodePtr &node); 240 void ConvertSpaceBatchNd(const FuncGraphPtr anf_graph) const; 241 AnfNodePtr CreateCast(const AnfNodePtr &input, const TypePtr &dst_type) const; 242 void ConvertReshape(const CNodePtr &node); 243 void ConvertHcomFusionId(const CNodePtr &node); 244 void ConvertHcclNode(const CNodePtr &node); 245 void ConvertAllToAllv(const CNodePtr &node); 246 void ConvertUniformReal(const CNodePtr &node); 247 void ConvertUpdateState(const CNodePtr &node); 248 void AddCommAttrForHcclNode(const CNodePtr &node, const OperatorPtr &converted_op) const; 249 void ConvertOCRRecPreHandle(const CNodePtr &node); 250 void ConvertConv2D(const CNodePtr &node); 251 void ConvertDynamicStitch(const CNodePtr &node); 252 void ConvertParallelGroupToHcom(const CNodePtr &node); 253 void ConvertParallelGroupIdToHcom(const CNodePtr &node); 254 std::vector<int64_t> CastToInt(const ValuePtr &value) const; 255 void TransDataType(const FuncGraphPtr &anf_graph) const; 256 void TransInputDataType(const CNodePtr &node, const std::string &node_name) const; 257 void TransAttrDataType(const CNodePtr &node, const std::string &node_name) const; 258 bool CheckCNode(const std::string &name, const CNodePtr node); 259 void SetNodeInput(AnfNodePtr node); 260 void UpdateOpDesc(AnfNodePtr node); 261 void SetSubgraph(const AnfNodePtr &node); 262 void ProcessSubgraph(const AnfNodePtr &node, const AnfNodePtr &branch_node, ParamIndexMap &branch_to_parent_node_map); 263 void BuildSaveCheckpointGraph(); 264 void DrawCNode(const CNodePtr node, const OpAdapterPtr adpt); 265 void UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const; 266 void UpdateConstOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const; 267 void AddGraphConstInput(const OperatorPtr &op); 268 AnfNodePtr ParseLoadInput(const CNodePtr &cnode) const; 269 void SetGraphInputs(std::vector<Operator> *inputs); 270 void SetGraphInputs(std::vector<Operator> *inputs, AnfNodeWeakPtrList *ge_inputs); 271 void TransformConstOp(const CNodePtr &node, const AnfNodePtr &pred); 272 void ProcessInputData(std::vector<Operator> *init_input, 273 std::unordered_set<std::string> *infer_need_update_parameter_names, const OperatorPtr ¶m_op, 274 const string &name, const std::shared_ptr<GeTensorDesc> &desc); 275 AnfNodePtr GetRealInputNode(const CNodePtr &node, const AnfNodePtr &input); 276 277 void ConvertWhileNode(const CNodePtr &node); 278 void CacheWhileGraph(const CNodePtr &cnode); 279 void ConvertWhileBody(const AnfNodePtr &node); 280 std::shared_ptr<std::vector<Operator>> GetWhileSubGraphInput(); 281 void BuildWhileSubGraph(); 282 void ConvertWhileCond(const AnfNodePtr &node); 283 void ConvertWhileAfter(const AnfNodePtr &node); 284 void BuildWhileAfterSubGraph(); 285 void BuildCallSubGraphs(const AnfNodePtr &node); 286 void GetCallNodeInputs(const CNodePtr &node); 287 std::vector<Operator> GetWhileBodyOutputs(); IsSubGraph()288 bool IsSubGraph() const { return graph_type_ == GraphType::kCond || graph_type_ == GraphType::kBody; } IsCondGraph()289 bool IsCondGraph() const { return graph_type_ == GraphType::kCond; } IsBodyGraph()290 bool IsBodyGraph() const { return graph_type_ == GraphType::kBody; } IsBranchGraph()291 bool IsBranchGraph() const { return graph_type_ == GraphType::kBranch; } IsAfterGraph()292 bool IsAfterGraph() const { return graph_type_ == GraphType::kAfter; } IsNormalGraph()293 bool IsNormalGraph() const { return graph_type_ == GraphType::kNormal; } 294 void SetParamIndexMap(const std::vector<AnfNodePtr> &graphs); 295 void SetWhileOutputHandle(const OperatorPtr &prev_while_op); 296 void GetWhileUsedInputIndex(const std::vector<AnfNodePtr> &graphs); 297 298 bool IsDataInput(const AnfNodePtr &node, const AnfNodePtr &input, size_t input_index); 299 void SetMakeTupleInput(const OpAdapterPtr &adpt, const CNodePtr &make_tuple_node); 300 void SetMergeInput(const OpAdapterPtr &adpt, const CNodePtr &merge_node); 301 bool IsMergeOrSwitchLayerInput(const CNodePtr &node) const; 302 void SetDynamicInputHandleByMultiInput(const OpAdapterPtr &adpt, const CNodePtr &node, 303 const CNodePtr &from_node_input); 304 void SetNodeControlInput(const AnfNodePtr &node, const AnfNodePtr &input); 305 void SetGraphOutputs(bool is_main_graph = false); 306 std::vector<OutHandler> GetInputHandles(const AnfNodePtr &node, const AnfNodePtr &input); 307 void FillEmptyInputsWithNoInputOp(std::vector<Operator> *); 308 bool IsDynamicInputBeforeNormalInput(const OpAdapterPtr &adpt, int *ge_input_size, 309 mindspore::HashMap<int, int> *ge_input_to_ms_input); 310 void SetDynamicInputBeforeNormalInput(const OpAdapterPtr &adpt, const CNodePtr &node, 311 const std::vector<AnfNodePtr> &inputs, const int &ge_input_size, 312 const mindspore::HashMap<int, int> &ge_input_to_ms_input, 313 std::vector<int64_t> *dyn_input_sizes); 314 315 // Identity Optimization 316 void IdentityOptimization(); 317 std::string GetGNodeName(const ::ge::GNode &node) const; 318 std::string GetGNodeType(const ::ge::GNode &node) const; 319 bool IsIdentityRedundant(const ::ge::GNode &node) const; 320 void RemoveIdentity(::ge::GNode identity_node); 321 void NoOpOptimization(); 322 bool IsNoOpRedundant(const ::ge::GNode &node) const; 323 void RemoveNoOp(::ge::GNode noop); 324 std::shared_ptr<std::vector<DfGraph>> BuildBranchGraphs(const CNodePtr &cnode); 325 void BuildInitDataGraph(const std::string &name); 326 bool IsConstantOp(const OperatorPtr &op) const; 327 void JudgeParamTransType(const bool &node_will_update, bool *as_ref_data, bool *as_constant) const; 328 OperatorPtr SetGraphInputsForNotVar(const AnfNodePtr &it, int64_t *index, std::vector<Operator> *inputs); 329 void GenFakeGraphInRefMode(); 330 void AddInputAttrsForESNode(const CNodePtr &node, const AnfNodePtr &input); 331 void RemoveIdentityForES(::ge::GNode node); 332 void ESOptimization(); 333 void ReplaceAllParameterToRefData(); 334 335 std::shared_ptr<AnfGraph> anf_graph_{nullptr}; 336 FuncGraphManagerPtr graph_manager_{nullptr}; 337 RefModeFlag ref_mode_type_ = RefModeFlag::kRefModeNone; 338 bool ref_mode_ = false; 339 std::vector<std::string> extra_variables_names_; 340 std::vector<std::string> ref_data_names_; 341 std::set<std::string> unsupported_ops_names_; 342 SetDynRefDataFunc dyn_ref_data_func_ = nullptr; 343 344 std::shared_ptr<DfGraph> df_graph_{nullptr}; 345 std::shared_ptr<DfGraph> init_graph_{nullptr}; 346 std::shared_ptr<DfGraph> save_ckp_graph_{nullptr}; 347 std::shared_ptr<DfGraph> restore_ckp_graph_{nullptr}; 348 std::shared_ptr<DfGraph> broadcast_graph_{nullptr}; 349 mindspore::HashMap<AnfNode *, DfGraph> branches_map_; 350 mindspore::HashMap<AnfNode *, OperatorPtr> op_cache_; 351 /* record "getnext"<->"out_handler" mapping */ 352 mindspore::HashMap<AnfNode *, OutHandler> out_handle_cache_; 353 /* record "value tuple"<->"out_handler vector" mapping */ 354 mindspore::HashMap<AnfNode *, std::shared_ptr<std::vector<OutHandler>>> tuple_out_handle_cache_; 355 mindspore::HashMap<AnfNode *, std::shared_ptr<std::vector<AnfNodePtr>>> branch_input_handle_cache_; 356 mindspore::HashMap<std::string, AnfNodePtr> params_; 357 mindspore::HashMap<std::string, OperatorPtr> vars_; 358 std::vector<OperatorPtr> ref_datas_; 359 std::vector<std::pair<::ge::Operator, std::string>> graph_outputs_; 360 std::vector<AnfNodePtr> graph_anf_outputs_; 361 std::vector<OperatorPtr> graph_const_inputs_; 362 std::vector<OperatorPtr> init_ops_; 363 std::vector<std::string> init_data_names_; 364 std::vector<OperatorPtr> broadcast_ops_; 365 std::vector<AnfNodePtr> inputs_; 366 ShapeArray input_shapes_; 367 Status error_ = SUCCESS; 368 bool training_ = false; 369 bool export_air_ = false; 370 bool distribute_ = false; 371 bool use_inputs_ = false; 372 bool dynamic_shape_inputs_ = false; 373 bool has_es_node_ = false; 374 375 AnfNodePtr while_cond_node_ = nullptr; 376 mindspore::HashMap<AnfNodePtr, std::shared_ptr<std::vector<DfGraph>>> while_dfgraph_cache_; 377 mindspore::HashMap<AnfNodePtr, std::shared_ptr<std::vector<DfGraph>>> call_dfgraph_cache_; 378 CNodePtr cur_while_node_ = nullptr; 379 size_t cur_while_node_out_size_ = 0; 380 mindspore::HashMap<size_t, OutHandler> while_const_input_index_; 381 mindspore::HashMap<size_t, OutHandler> prev_while_const_input_index_; 382 mindspore::HashMap<size_t, size_t> prev_cond_to_while_out_index_; 383 mindspore::HashMap<OperatorPtr, std::shared_ptr<tensor::Tensor>> const_op_to_value_; 384 AnfNodePtr prev_while_node_ = nullptr; 385 size_t prev_while_node_out_size_ = 0; 386 387 mindspore::HashMap<AnfNodePtr, std::vector<AnfNodePtr>> while_graph_cache_; 388 mindspore::HashMap<AnfNodePtr, std::shared_ptr<std::vector<OutHandler>>> call_input_handle_cache_; 389 mindspore::HashMap<AnfNodePtr, std::shared_ptr<std::vector<OutHandler>>> while_output_handle_cache_; 390 AnfNodePtr call_node_in_while_body_ = nullptr; 391 GraphType graph_type_ = GraphType::kNormal; 392 393 ParamIndexMap body_cond_map_; 394 ParamIndexMap after_cond_map_; 395 ParamIndexMap prev_after_cond_map_; 396 mindspore::HashMap<size_t, OperatorPtr> subgraph_input_cache_; 397 398 std::set<size_t> while_used_input_index_; 399 std::set<size_t> prev_while_used_input_index_; 400 401 mindspore::HashMap<size_t, OutHandler> bypass_node_prev_handle_cache_; 402 mindspore::HashMap<size_t, OutHandler> bypass_node_handle_cache_; 403 size_t case_call_input_size_ = 0; 404 bool is_kernel_graph_ = false; 405 406 std::string phase_prefix_; 407 bool offline_convert_ = false; 408 void AddInputInDataSink(std::vector<Operator> *inputs); 409 }; 410 } // namespace transform 411 } // namespace mindspore 412 413 #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_CONVERT_H_ 414