1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_COST_MODEL_H_ 17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_COST_MODEL_H_ 18 19 #include <map> 20 #include <memory> 21 #include <set> 22 #include <sstream> 23 #include <string> 24 #include <tuple> 25 #include <vector> 26 27 #include "base/base.h" 28 #include "include/backend/anf_runtime_algorithm.h" 29 #include "include/common/utils/anfalgo.h" 30 #include "include/backend/optimizer/optimizer.h" 31 #include "backend/common/graph_kernel/parallel_cost_model.h" 32 #include "include/backend/kernel_graph.h" 33 #include "include/common/utils/python_adapter.h" 34 #include "utils/ms_context.h" 35 36 namespace mindspore::graphkernel { 37 class DimInfo { 38 public: 39 DimInfo() = default; ~DimInfo()40 virtual ~DimInfo() {} 41 virtual std::string ToString() = 0; 42 }; 43 44 class CommonDimInfo : public DimInfo { 45 public: CommonDimInfo(size_t dim)46 explicit CommonDimInfo(size_t dim) : dim_info_(dim) {} ~CommonDimInfo()47 ~CommonDimInfo() {} set_dim_info(size_t d)48 void set_dim_info(size_t d) { dim_info_ = d; } dim_info()49 size_t dim_info() const { return dim_info_; } 50 std::string ToString() override; 51 52 private: 53 size_t dim_info_; 54 }; 55 56 using DimInfoPtr = std::shared_ptr<DimInfo>; 57 using CommonDimInfoPtr = std::shared_ptr<CommonDimInfo>; 58 59 class FusionInfo { 60 public: 61 FusionInfo() = default; FusionInfo(const std::string & type)62 explicit FusionInfo(const std::string &type) : fusion_type_(type) {} 63 virtual ~FusionInfo() = default; FusionType()64 std::string FusionType() const { return fusion_type_; } ExistTypeInfo()65 virtual bool ExistTypeInfo() { return false; } 66 67 private: 68 std::string fusion_type_{"none"}; 69 }; 70 71 class BlockFusionInfo : public FusionInfo { 72 public: BlockFusionInfo()73 BlockFusionInfo() : FusionInfo("block_fusion") {} 74 ~BlockFusionInfo() = default; ExistTypeInfo()75 bool ExistTypeInfo() override { return false; } 76 }; 77 78 class BlockPipelineFusionInfo : public FusionInfo { 79 public: BlockPipelineFusionInfo(const std::vector<std::vector<int>> & ids)80 explicit BlockPipelineFusionInfo(const std::vector<std::vector<int>> &ids) 81 : FusionInfo("block_pipeline_fusion"), pipeline_ids_(ids) {} 82 ~BlockPipelineFusionInfo() = default; ExistTypeInfo()83 bool ExistTypeInfo() override { return true; } PipelineIds()84 std::vector<std::vector<int>> PipelineIds() { return pipeline_ids_; } 85 86 private: 87 std::vector<std::vector<int>> pipeline_ids_; 88 }; 89 90 using FusionInfoPtr = std::shared_ptr<FusionInfo>; 91 using BlockFusionInfoPtr = std::shared_ptr<BlockFusionInfo>; 92 using BlockPipelineFusionInfoPtr = std::shared_ptr<BlockPipelineFusionInfo>; 93 94 class ParallelCostModel { 95 public: ParallelCostModel()96 ParallelCostModel() {} ~ParallelCostModel()97 ~ParallelCostModel() {} 98 int64_t GetNodeCalAmount(const AnfNodePtr &node) const; 99 std::tuple<std::vector<DimInfoPtr>, int, FusionInfoPtr> CalFuseInfo(const AnfNodePtrList &nodes) const; 100 101 private: 102 FusionInfoPtr ProcessFusionInfo(const py::object &fusion_type, const py::object &type_info) const; 103 }; 104 105 using ParallelCostModelPtr = std::shared_ptr<ParallelCostModel>; 106 107 class ParellelCostModelWarehouse { 108 public: Instance()109 static ParellelCostModelWarehouse &Instance() { 110 static ParellelCostModelWarehouse instance = ParellelCostModelWarehouse(); 111 return instance; 112 } 113 ParallelCostModelPtr GetParallelCostModel(const std::string &target) const; 114 115 private: ParellelCostModelWarehouse()116 ParellelCostModelWarehouse() { cost_model_ = std::make_shared<ParallelCostModel>(); } 117 ~ParellelCostModelWarehouse() = default; 118 ParallelCostModelPtr cost_model_; 119 }; 120 } // namespace mindspore::graphkernel 121 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_COST_MODEL_H_ 122