1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_INCLUDE_TRANSFORM_GRAPH_IR_UTILS_H_ 18 #define MINDSPORE_CCSRC_INCLUDE_TRANSFORM_GRAPH_IR_UTILS_H_ 19 #include <string> 20 #include <map> 21 #include <memory> 22 #include <vector> 23 #include "transform/graph_ir/convert.h" 24 #include "transform/graph_ir/graph_runner.h" 25 #include "include/transform/graph_ir/types.h" 26 #include "transform/graph_ir/op_adapter_base.h" 27 #include "include/common/utils/config_manager.h" 28 #include "include/backend/visible.h" 29 30 namespace mindspore { 31 constexpr char BROADCAST_GRAPH_NAME[] = "broadcast_subgraph"; 32 33 namespace transform { 34 using OpAdapterPtr = std::shared_ptr<transform::BaseOpAdapter>; 35 using GraphRunnerPtr = std::shared_ptr<transform::GraphRunner>; 36 using DfGraphConvertorPtr = std::shared_ptr<transform::DfGraphConvertor>; 37 OpAdapterPtr FindAdapter(const std::string &op_name, bool train = false); 38 OpAdapterPtr FindAdapter(AnfNodePtr node, bool train = false); 39 40 bool IsPartialSuccNode(const AnfNodePtr node); 41 bool IsWhileNode(const AnfNodePtr &node); 42 bool IsCallNode(const AnfNodePtr &node); 43 bool IsIfNode(const AnfNodePtr &node); 44 bool IsCaseNode(const AnfNodePtr &node); 45 std::string GetCNodeTargetFuncName(const CNodePtr cnode); 46 bool IsPartialCNode(const AnfNodePtr node); 47 bool IsInitDataSetQueueNode(const AnfNodePtr &node); 48 49 void ClearGeSessionAndRunner(); 50 void InitializeAoeUtil(); 51 void DestroyAoeUtil(); 52 void EnableAoeOffline(); 53 54 // convert_type 55 std::vector<GeTensorPtr> ConvertInputTensors(const std::vector<MeTensorPtr> &me_tensors, const std::string &format); 56 std::vector<MeTensorPtr> ConvertGeTensors(const std::vector<GeTensorPtr> &ge_tensors); 57 GeDataType ConvertDataType(const MeDataType &type); 58 59 MeTensorPtr ConvertGeTensor(const GeTensorPtr &ge_tensor, const ShapeVector &request_dims, bool ref_mem = false); 60 MeTensorPtr ConvertGeTensor(const GeTensorPtr &tensor); 61 MeTensorPtr ConvertGeTensor(const GeTensorPtr &tensor, const TypeId &me_type); 62 63 // df graph manager 64 std::shared_ptr<transform::GraphRunner> GetGraphRunner(); 65 std::shared_ptr<transform::GraphRunner> CheckAndGetGraphRunner(const transform::RunOptions &run_options); 66 BACKEND_EXPORT std::shared_ptr<::ge::Session> GetGeSession(); 67 BACKEND_EXPORT void SetGeSession(const std::shared_ptr<::ge::Session> &sess_ptr); 68 BACKEND_EXPORT GraphRunnerPtr NewGraphRunner(const GraphRunnerOptions &options); 69 BACKEND_EXPORT void SetGraphRunner(const GraphRunnerPtr &runner); 70 BACKEND_EXPORT void ClearGraph(); 71 BACKEND_EXPORT Status AddGraph(const std::string &name, const DfGraphPtr &graph, const OptionMap &options = {}, 72 const bool &is_cloud = false, const bool &need_aoe = false); 73 BACKEND_EXPORT void SetAnfGraph(const std::string &name, const AnfGraphPtr &anf_graph_ptr); 74 BACKEND_EXPORT DfGraphWrapperPtr GetGraphByName(const std::string &name); 75 BACKEND_EXPORT void AddOptimizeGraph(const std::string &name); 76 77 FuncGraphPtr GetAnfGraph(uint32_t graph_id); 78 79 // convert 80 BACKEND_EXPORT DfGraphConvertorPtr NewConverter(const FuncGraphPtr &graph, const std::string &phase_prefix = "", 81 RefModeFlag ref_mode_type = RefModeFlag::kRefModeEnv, 82 bool offline_convert = false); 83 84 BACKEND_EXPORT void SetTraining(const DfGraphConvertorPtr &converter, bool training); 85 BACKEND_EXPORT void SetExportAir(const DfGraphConvertorPtr &converter, bool export_air); 86 BACKEND_EXPORT void BuildGraph(const std::string &name, const DfGraphConvertorPtr &converter, 87 const std::map<std::string, std::shared_ptr<tensor::Tensor>> &maps); 88 void GenerateBroadcastGraph(const DfGraphConvertorPtr &converter, const TensorOrderMap &tensors); 89 BACKEND_EXPORT void GenerateCheckpointGraph(const DfGraphConvertorPtr &converter); 90 BACKEND_EXPORT int ErrCode(const DfGraphConvertorPtr &converter); 91 BACKEND_EXPORT bool ConvertCheck(const AnfNodePtr &node); 92 BACKEND_EXPORT bool DynamicShapeSupportCheck(const AnfNodePtr &node, bool train = true); 93 BACKEND_EXPORT bool SinkGraphCheck(const AnfNodePtr &node, bool train = true); 94 BACKEND_EXPORT void GenFakeGraph(const std::string &name, const DfGraphConvertorPtr &converter); 95 96 BACKEND_EXPORT DfGraphPtr GetComputeGraph(const DfGraphConvertorPtr &converter); 97 BACKEND_EXPORT DfGraphPtr GetInitGraph(const DfGraphConvertorPtr &converter); 98 BACKEND_EXPORT DfGraphPtr GetSaveCheckpointGraph(const DfGraphConvertorPtr &converter); 99 BACKEND_EXPORT DfGraphPtr GetBroadcastGraph(const DfGraphConvertorPtr &converter); 100 101 // new session 102 BACKEND_EXPORT std::shared_ptr<::ge::Session> NewSession(const SessionOptions &sess_options); 103 104 Status RunGraph(const std::shared_ptr<GraphRunner> &runner, const RunOptions &options, 105 const std::vector<GeTensorPtr> &inputs, std::vector<GeTensorPtr> *outputs); 106 107 Status RunGraphAsync(const std::shared_ptr<GraphRunner> &runner, const RunOptions &options, 108 const std::vector<GeTensorPtr> &inputs, std::vector<GeTensorPtr> *outputs); 109 110 Status RunGraphWithStreamAsync(const std::shared_ptr<GraphRunner> &runner, const RunOptions &options, void *stream, 111 const std::vector<GeTensor> &inputs, std::vector<GeTensor> *outputs); 112 113 Status RegisterExternalAllocator(const std::shared_ptr<GraphRunner> &runner, const void *const stream, 114 GeAllocatorPtr allocator); 115 116 Status UnregisterExternalAllocator(const std::shared_ptr<GraphRunner> &runner, const void *const stream); 117 118 transform::Status CompileDatasetGraph(const DatasetGraphParam ¶m, const std::string &phase = "dataset"); 119 } // namespace transform 120 } // namespace mindspore 121 #endif // MINDSPORE_CCSRC_INCLUDE_TRANSFORM_GRAPH_IR_UTILS_H_ 122