1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_AOE_UTIL_H_ 18 #define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_AOE_UTIL_H_ 19 #include <map> 20 #include <memory> 21 #include <set> 22 #include <string> 23 #include <vector> 24 #include "include/transform/graph_ir/types.h" 25 #include "utils/dlopen_macro.h" 26 27 namespace mindspore { 28 namespace transform { 29 30 using AoeStatus = int32_t; 31 constexpr AoeStatus AOE_SUCCESS = 0; 32 constexpr AoeStatus AOE_ERROR_NON_OPTIMIZE_GRAPH = 8; 33 34 ORIGIN_METHOD(AoeInitialize, AoeStatus, const std::map<::ge::AscendString, ::ge::AscendString> &); 35 ORIGIN_METHOD(AoeFinalize, AoeStatus); 36 ORIGIN_METHOD(AoeCreateSession, AoeStatus, uint64_t &); 37 ORIGIN_METHOD(AoeSetGeSession, AoeStatus, uint64_t, ::ge::Session *); 38 ORIGIN_METHOD(AoeSetTuningGraph, AoeStatus, uint64_t, const ::ge::Graph &); 39 ORIGIN_METHOD(AoeTuningGraph, AoeStatus, uint64_t, const std::map<::ge::AscendString, ::ge::AscendString> &); 40 ORIGIN_METHOD(AoeDestroySession, AoeStatus, uint64_t); 41 42 class AoeUtil { 43 public: 44 Status AoeOnlineGeGraph(const std::shared_ptr<::ge::Session> &ge_session, const transform::DfGraphPtr &graph) const; 45 static AoeUtil &GetInstance(); 46 ~AoeUtil(); 47 void Initialize(); 48 void Destroy(); 49 void SaveOptimizedGraph(const int32_t &graph_id); 50 bool IsSaveOptimizedGraph(const int32_t &graph_id) const; 51 void RemoveWaitOptimizedGraph(const std::set<std::string> &optimized_graph_names); 52 void AddOptimizeGraph(const std::string &graph_name); 53 std::set<std::string> GetWaitOptimizeGraph() const; 54 void SetOfflineEnvDumpGeGraph(); 55 56 private: 57 std::set<std::string> wait_optimize_graphs_; 58 std::set<int32_t> optimized_graphs_id_; 59 AoeUtil(); 60 bool initialize_; 61 62 Status AoeGeGraph(::ge::Session *ge_session, const transform::DfGraphPtr &graph, 63 const std::map<::ge::AscendString, ::ge::AscendString> &tuningOptions) const; 64 void *plugin_handle_ = nullptr; 65 std::vector<void *> depend_handler_; 66 AoeInitializeFunObj aoe_initialize_ = nullptr; 67 AoeFinalizeFunObj aoe_finalize_ = nullptr; 68 AoeCreateSessionFunObj aoe_create_session_ = nullptr; 69 AoeSetGeSessionFunObj aoe_set_ge_gession_ = nullptr; 70 AoeSetTuningGraphFunObj aoe_set_tuning_graph_ = nullptr; 71 AoeTuningGraphFunObj aoe_tuning_graph_ = nullptr; 72 AoeDestroySessionFunObj aoe_destroy_session_ = nullptr; 73 }; 74 } // namespace transform 75 } // namespace mindspore 76 77 #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_AOE_UTIL_H_ 78