• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/optimizer/graph_kernel/parallel_cost_model.h"
18 
19 #include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
20 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
21 #include "pipeline/jit/parse/python_adapter.h"
22 
23 namespace mindspore {
24 namespace opt {
ToString()25 std::string CommonDimInfo::ToString() {
26   std::ostringstream buffer;
27   buffer << "Dim(" << dim_info_ << ")";
28   return buffer.str();
29 }
30 
GetNodeCalAmount(const AnfNodePtr & node) const31 int ParallelCostModel::GetNodeCalAmount(const AnfNodePtr &node) const {
32   nlohmann::json json_desc;
33   AnfNodePtrList nodes = {node};
34   DumpOption dump_option;
35   if (!AnfToJsonDesc(nodes, dump_option, &json_desc)) {
36     MS_LOG(EXCEPTION) << "Collect json desc failed.";
37   }
38 
39   auto json_desc_str = json_desc.dump();
40   auto ret = parse::python_adapter::CallPyFn(kGraphKernelModule, kGraphKernelGetNodeCalAmount, json_desc_str);
41   if (py::isinstance<py::none>(ret)) {
42     MS_LOG(EXCEPTION) << "CallPyFn: [" << kGraphKernelSplitFunc << "] return invalid result. input json:\n"
43                       << json_desc_str;
44   }
45   return py::cast<int>(ret);
46 }
47 
CalFuseInfo(const AnfNodePtrList & nodes) const48 std::tuple<std::vector<DimInfoPtr>, int, FusionInfoPtr> ParallelCostModel::CalFuseInfo(
49   const AnfNodePtrList &nodes) const {
50   nlohmann::json json_desc;
51   std::vector<AnfNodePtrList> graphs;
52   std::transform(nodes.begin(), nodes.end(), std::back_inserter(graphs),
53                  [](const AnfNodePtr &node) -> AnfNodePtrList { return {node}; });
54   DumpOption dump_option;
55   if (!AnfToJsonDesc(graphs, dump_option, &json_desc)) {
56     MS_LOG(EXCEPTION) << "Collect json desc failed.";
57   }
58 
59   auto json_desc_str = json_desc.dump();
60   auto ret = parse::python_adapter::CallPyFn(kGraphKernelModule, kGraphKernelEstimateOps, json_desc_str);
61   if (py::isinstance<py::none>(ret)) {
62     MS_LOG(EXCEPTION) << "CallPyFn: [" << kGraphKernelSplitFunc << "] return invalid result. input json:\n"
63                       << json_desc_str;
64   }
65 
66   py::tuple ret_tuple = py::cast<py::tuple>(ret);
67   if (!py::isinstance<py::tuple>(ret_tuple) || ret_tuple.size() != 4) {
68     MS_LOG(EXCEPTION) << "Parallel cost model should return a tuple with two elements!";
69   }
70 
71   std::vector<DimInfoPtr> dim_infos;
72   py::list dim_list = py::cast<py::list>(ret_tuple[0]);
73   for (size_t i = 0; i < dim_list.size(); ++i) {
74     dim_infos.push_back(std::make_shared<CommonDimInfo>(py::cast<int>(dim_list[i])));
75   }
76   int benefit = py::cast<int>(ret_tuple[1]);
77   auto fusion_info = ProcessFusionInfo(ret_tuple[2], ret_tuple[3]);
78 
79   return std::make_tuple(dim_infos, benefit, fusion_info);
80 }
81 
ProcessFusionInfo(const py::object & fusion_type,const py::object & type_info) const82 FusionInfoPtr ParallelCostModel::ProcessFusionInfo(const py::object &fusion_type, const py::object &type_info) const {
83   if (!py::isinstance<py::str>(fusion_type)) {
84     MS_LOG(EXCEPTION) << "Fusion type for parallel is invalid!";
85   }
86 
87   std::string fusion_type_name = py::cast<std::string>(fusion_type);
88 
89   FusionInfoPtr fusion_info;
90   if (fusion_type_name == "block_fusion") {
91     fusion_info = std::make_shared<BlockFusionInfo>();
92   } else if (fusion_type_name == "block_pipeline_fusion") {
93     if (!py::isinstance<py::list>(type_info)) {
94       MS_LOG(EXCEPTION) << "Fusion type info for block pipe fusion type is invalid!";
95     }
96     std::vector<std::vector<int>> pipeline_ids;
97     py::list pipeline_ids_list = py::cast<py::list>(type_info);
98     for (size_t i = 0; i < pipeline_ids_list.size(); ++i) {
99       std::vector<int> part_ids;
100       py::list inner_ids_list = py::cast<py::list>(pipeline_ids_list[i]);
101       for (size_t j = 0; j < inner_ids_list.size(); ++j) {
102         part_ids.push_back(py::cast<int>(inner_ids_list[j]));
103       }
104       pipeline_ids.push_back(part_ids);
105     }
106 
107     fusion_info = std::make_shared<BlockPipelineFusionInfo>(pipeline_ids);
108   } else {
109     MS_LOG(EXCEPTION) << "Unsupported parallel fusion type: " << fusion_type_name;
110   }
111   return fusion_info;
112 }
113 
GetParallelCostModel(const std::string & target) const114 ParallelCostModelPtr ParellelCostModelWarehouse::GetParallelCostModel(const std::string &target) const {
115   if (target != kGPUDevice) {
116     MS_LOG(EXCEPTION) << "Parallel cost model only support " << kGPUDevice << " now.";
117   }
118   return cost_model_;
119 }
120 }  // namespace opt
121 }  // namespace mindspore
122