• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_COST_MODEL_H_
17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_COST_MODEL_H_
18 
19 #include <map>
20 #include <memory>
21 #include <set>
22 #include <sstream>
23 #include <string>
24 #include <tuple>
25 #include <vector>
26 
27 #include "base/base.h"
28 #include "include/backend/anf_runtime_algorithm.h"
29 #include "include/common/utils/anfalgo.h"
30 #include "include/backend/optimizer/optimizer.h"
31 #include "backend/common/graph_kernel/parallel_cost_model.h"
32 #include "include/backend/kernel_graph.h"
33 #include "include/common/utils/python_adapter.h"
34 #include "utils/ms_context.h"
35 
36 namespace mindspore::graphkernel {
// Abstract per-node dimension descriptor produced by the parallel cost model
// (ParallelCostModel::CalFuseInfo returns a std::vector<DimInfoPtr>).
// Concrete subclasses must implement ToString().
class DimInfo {
 public:
  DimInfo() = default;
  // Polymorphic base held through std::shared_ptr<DimInfo> (DimInfoPtr), so the
  // destructor is virtual to dispatch destruction to the concrete subclass.
  virtual ~DimInfo() = default;
  // Returns a human-readable representation of this dimension info.
  virtual std::string ToString() = 0;
};
43 
44 class CommonDimInfo : public DimInfo {
45  public:
CommonDimInfo(size_t dim)46   explicit CommonDimInfo(size_t dim) : dim_info_(dim) {}
~CommonDimInfo()47   ~CommonDimInfo() {}
set_dim_info(size_t d)48   void set_dim_info(size_t d) { dim_info_ = d; }
dim_info()49   size_t dim_info() const { return dim_info_; }
50   std::string ToString() override;
51 
52  private:
53   size_t dim_info_;
54 };
55 
56 using DimInfoPtr = std::shared_ptr<DimInfo>;
57 using CommonDimInfoPtr = std::shared_ptr<CommonDimInfo>;
58 
// Base descriptor of the fusion strategy chosen by the parallel cost model.
// It carries only a type tag; a default-constructed instance is tagged "none".
// Subclasses that attach extra data report it through ExistTypeInfo().
class FusionInfo {
 public:
  FusionInfo() : fusion_type_("none") {}
  explicit FusionInfo(const std::string &type) : fusion_type_(type) {}
  virtual ~FusionInfo() = default;

  // Returns a copy of the fusion type tag.
  std::string FusionType() const {
    return std::string(fusion_type_);
  }

  // Whether this object carries type-specific data beyond the tag.
  // The base class never does.
  virtual bool ExistTypeInfo() {
    constexpr bool kHasExtraInfo = false;
    return kHasExtraInfo;
  }

 private:
  std::string fusion_type_;
};
70 
71 class BlockFusionInfo : public FusionInfo {
72  public:
BlockFusionInfo()73   BlockFusionInfo() : FusionInfo("block_fusion") {}
74   ~BlockFusionInfo() = default;
ExistTypeInfo()75   bool ExistTypeInfo() override { return false; }
76 };
77 
78 class BlockPipelineFusionInfo : public FusionInfo {
79  public:
BlockPipelineFusionInfo(const std::vector<std::vector<int>> & ids)80   explicit BlockPipelineFusionInfo(const std::vector<std::vector<int>> &ids)
81       : FusionInfo("block_pipeline_fusion"), pipeline_ids_(ids) {}
82   ~BlockPipelineFusionInfo() = default;
ExistTypeInfo()83   bool ExistTypeInfo() override { return true; }
PipelineIds()84   std::vector<std::vector<int>> PipelineIds() { return pipeline_ids_; }
85 
86  private:
87   std::vector<std::vector<int>> pipeline_ids_;
88 };
89 
90 using FusionInfoPtr = std::shared_ptr<FusionInfo>;
91 using BlockFusionInfoPtr = std::shared_ptr<BlockFusionInfo>;
92 using BlockPipelineFusionInfoPtr = std::shared_ptr<BlockPipelineFusionInfo>;
93 
94 class ParallelCostModel {
95  public:
ParallelCostModel()96   ParallelCostModel() {}
~ParallelCostModel()97   ~ParallelCostModel() {}
98   int64_t GetNodeCalAmount(const AnfNodePtr &node) const;
99   std::tuple<std::vector<DimInfoPtr>, int, FusionInfoPtr> CalFuseInfo(const AnfNodePtrList &nodes) const;
100 
101  private:
102   FusionInfoPtr ProcessFusionInfo(const py::object &fusion_type, const py::object &type_info) const;
103 };
104 
105 using ParallelCostModelPtr = std::shared_ptr<ParallelCostModel>;
106 
107 class ParellelCostModelWarehouse {
108  public:
Instance()109   static ParellelCostModelWarehouse &Instance() {
110     static ParellelCostModelWarehouse instance = ParellelCostModelWarehouse();
111     return instance;
112   }
113   ParallelCostModelPtr GetParallelCostModel(const std::string &target) const;
114 
115  private:
ParellelCostModelWarehouse()116   ParellelCostModelWarehouse() { cost_model_ = std::make_shared<ParallelCostModel>(); }
117   ~ParellelCostModelWarehouse() = default;
118   ParallelCostModelPtr cost_model_;
119 };
120 }  // namespace mindspore::graphkernel
121 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_COST_MODEL_H_
122