1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/auto_parallel/costmodel.h"
18 #include <cmath>
19 #include <numeric>
20 #include <utility>
21 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
22
23 namespace mindspore {
24 namespace parallel {
Simplify(CostPtrList * clist_ptrs)25 void Simplify(CostPtrList *clist_ptrs) {
26 const auto run_phase = CostModelContext::GetInstance()->run_phase();
27 if (run_phase == TRAINING_PHASE) {
28 // training phase
29 SimplifyForDecreasingCommunicationWithPartialPara(clist_ptrs);
30 } else {
31 // inference phase
32 SimplifyForDecreasingCommunicationForward(clist_ptrs);
33 }
34 }
SimplifyForDecreasingCommunicationForward(CostPtrList * clist_ptrs)35 void SimplifyForDecreasingCommunicationForward(CostPtrList *clist_ptrs) {
36 // Sort the cost_list with the computation_cost_ increasing, and communication_forward decreasing order. This method
37 // excludes the cost with greater computation_cost_ and greater communication_forward.
38 // E.g. clist_ptrs = {<100, 20>, <200, 10>, <300, 50>}. After this method, clist_ptrs = {<200, 10>, <100, 20>}
39 const auto simplify_cal = CostModelContext::GetInstance()->costmodel_simplify_cal();
40 if (!simplify_cal) {
41 return;
42 }
43 MS_EXCEPTION_IF_NULL(clist_ptrs);
44 std::vector<size_t> id(clist_ptrs->size());
45 std::iota(id.begin(), id.end(), size_t(0));
46 std::sort(id.begin(), id.end(), [&clist_ptrs](size_t x, size_t y) {
47 return clist_ptrs->at(x)->computation_cost_ < clist_ptrs->at(y)->computation_cost_;
48 });
49 CostPtrList ret;
50 for (size_t i = 0; i < clist_ptrs->size(); ++i) {
51 if ((ret.size() == size_t(0)) ||
52 (clist_ptrs->at(id[i])->communication_forward_ < ret.back()->communication_forward_)) {
53 ret.emplace_back(std::move(clist_ptrs->at(id[i])));
54 }
55 }
56 *clist_ptrs = std::move(ret);
57 }
58
SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList * clist_ptrs)59 void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) {
60 // Sort the cost_list with the computation_cost_ increasing, and communication_with_partial_para_cost decreasing
61 // order. This method excludes the cost with greater computation_cost_ and greater communication_without_para_cost.
62 const auto simplify_cal = CostModelContext::GetInstance()->costmodel_simplify_cal();
63 if (!simplify_cal) {
64 return;
65 }
66 MS_EXCEPTION_IF_NULL(clist_ptrs);
67 std::vector<size_t> id(clist_ptrs->size());
68 std::iota(id.begin(), id.end(), size_t(0));
69 std::sort(id.begin(), id.end(), [&clist_ptrs](size_t x, size_t y) {
70 return clist_ptrs->at(x)->computation_cost_ < clist_ptrs->at(y)->computation_cost_;
71 });
72 CostPtrList ret;
73 for (size_t i = 0; i < clist_ptrs->size(); ++i) {
74 if ((ret.size() == size_t(0)) ||
75 (clist_ptrs->at(id[i])->communication_with_partial_para_ < ret.back()->communication_with_partial_para_)) {
76 ret.emplace_back(std::move(clist_ptrs->at(id[i])));
77 }
78 }
79 *clist_ptrs = std::move(ret);
80 }
81
RefineForPracticalCost(const CostPtr & origin_cost,bool is_redistribution)82 void RefineForPracticalCost(const CostPtr &origin_cost, bool is_redistribution) {
83 MS_EXCEPTION_IF_NULL(origin_cost);
84 const auto comm_threshold = CostModelContext::GetInstance()->costmodel_communi_threshold();
85 const auto comm_const = CostModelContext::GetInstance()->costmodel_communi_const();
86 const auto comm_bias = CostModelContext::GetInstance()->costmodel_communi_bias();
87 const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
88 if (is_redistribution) {
89 // Redistribution cost
90 if ((origin_cost->communication_redis_forward_ > EPS) &&
91 (origin_cost->communication_redis_forward_ <= comm_threshold)) {
92 origin_cost->communication_redis_forward_ = comm_const;
93 } else if (origin_cost->communication_redis_forward_ > comm_threshold) {
94 origin_cost->communication_redis_forward_ += comm_bias;
95 }
96 if ((origin_cost->communication_redis_backward_ > EPS) &&
97 (origin_cost->communication_redis_backward_ <= comm_threshold)) {
98 origin_cost->communication_redis_backward_ = comm_const;
99 } else if (origin_cost->communication_redis_backward_ > comm_threshold) {
100 origin_cost->communication_redis_backward_ += comm_bias;
101 }
102 origin_cost->communication_cost_ =
103 origin_cost->communication_redis_forward_ + origin_cost->communication_redis_backward_;
104 origin_cost->communication_without_parameter_ = origin_cost->communication_cost_;
105 origin_cost->communication_with_partial_para_ = origin_cost->communication_cost_;
106 } else {
107 // Operator cost
108 double backward = 0.0;
109 if (std::abs(origin_cost->communication_cost_ - origin_cost->communication_without_parameter_) > EPS) {
110 backward = origin_cost->communication_cost_ - origin_cost->communication_without_parameter_;
111 }
112 // forward cost
113 if ((origin_cost->communication_without_parameter_ > EPS) &&
114 (origin_cost->communication_without_parameter_ <= comm_threshold)) {
115 origin_cost->communication_without_parameter_ = comm_const;
116 } else if (origin_cost->communication_without_parameter_ > comm_threshold) {
117 origin_cost->communication_without_parameter_ += comm_bias;
118 }
119 // total
120 if (origin_cost->communication_cost_ > EPS) {
121 origin_cost->communication_cost_ = origin_cost->communication_without_parameter_ + backward;
122 }
123 if (origin_cost->communication_with_partial_para_ > EPS) {
124 origin_cost->communication_with_partial_para_ = origin_cost->communication_without_parameter_ + gamma * backward;
125 }
126 }
127 }
128 } // namespace parallel
129 } // namespace mindspore
130