• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "frontend/parallel/auto_parallel/costmodel.h"
#include <algorithm>
#include <cmath>
#include <numeric>
#include <utility>
#include <vector>
#include "frontend/parallel/auto_parallel/graph_costmodel.h"
22 
23 namespace mindspore {
24 namespace parallel {
Simplify(CostPtrList * clist_ptrs)25 void Simplify(CostPtrList *clist_ptrs) {
26   const auto run_phase = CostModelContext::GetInstance()->run_phase();
27   if (run_phase == TRAINING_PHASE) {
28     // training phase
29     SimplifyForDecreasingCommunicationWithPartialPara(clist_ptrs);
30   } else {
31     // inference phase
32     SimplifyForDecreasingCommunicationForward(clist_ptrs);
33   }
34 }
SimplifyForDecreasingCommunicationForward(CostPtrList * clist_ptrs)35 void SimplifyForDecreasingCommunicationForward(CostPtrList *clist_ptrs) {
36   // Sort the cost_list with the computation_cost_ increasing, and communication_forward decreasing order. This method
37   // excludes the cost with greater computation_cost_ and greater communication_forward.
38   // E.g. clist_ptrs = {<100, 20>, <200, 10>, <300, 50>}. After this method, clist_ptrs = {<200, 10>, <100, 20>}
39   const auto simplify_cal = CostModelContext::GetInstance()->costmodel_simplify_cal();
40   if (!simplify_cal) {
41     return;
42   }
43   MS_EXCEPTION_IF_NULL(clist_ptrs);
44   std::vector<size_t> id(clist_ptrs->size());
45   std::iota(id.begin(), id.end(), size_t(0));
46   std::sort(id.begin(), id.end(), [&clist_ptrs](size_t x, size_t y) {
47     return clist_ptrs->at(x)->computation_cost_ < clist_ptrs->at(y)->computation_cost_;
48   });
49   CostPtrList ret;
50   for (size_t i = 0; i < clist_ptrs->size(); ++i) {
51     if ((ret.size() == size_t(0)) ||
52         (clist_ptrs->at(id[i])->communication_forward_ < ret.back()->communication_forward_)) {
53       ret.emplace_back(std::move(clist_ptrs->at(id[i])));
54     }
55   }
56   *clist_ptrs = std::move(ret);
57 }
58 
SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList * clist_ptrs)59 void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) {
60   // Sort the cost_list with the computation_cost_ increasing, and communication_with_partial_para_cost decreasing
61   // order. This method excludes the cost with greater computation_cost_ and greater communication_without_para_cost.
62   const auto simplify_cal = CostModelContext::GetInstance()->costmodel_simplify_cal();
63   if (!simplify_cal) {
64     return;
65   }
66   MS_EXCEPTION_IF_NULL(clist_ptrs);
67   std::vector<size_t> id(clist_ptrs->size());
68   std::iota(id.begin(), id.end(), size_t(0));
69   std::sort(id.begin(), id.end(), [&clist_ptrs](size_t x, size_t y) {
70     return clist_ptrs->at(x)->computation_cost_ < clist_ptrs->at(y)->computation_cost_;
71   });
72   CostPtrList ret;
73   for (size_t i = 0; i < clist_ptrs->size(); ++i) {
74     if ((ret.size() == size_t(0)) ||
75         (clist_ptrs->at(id[i])->communication_with_partial_para_ < ret.back()->communication_with_partial_para_)) {
76       ret.emplace_back(std::move(clist_ptrs->at(id[i])));
77     }
78   }
79   *clist_ptrs = std::move(ret);
80 }
81 
RefineForPracticalCost(const CostPtr & origin_cost,bool is_redistribution)82 void RefineForPracticalCost(const CostPtr &origin_cost, bool is_redistribution) {
83   MS_EXCEPTION_IF_NULL(origin_cost);
84   const auto comm_threshold = CostModelContext::GetInstance()->costmodel_communi_threshold();
85   const auto comm_const = CostModelContext::GetInstance()->costmodel_communi_const();
86   const auto comm_bias = CostModelContext::GetInstance()->costmodel_communi_bias();
87   const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
88   if (is_redistribution) {
89     // Redistribution cost
90     if ((origin_cost->communication_redis_forward_ > EPS) &&
91         (origin_cost->communication_redis_forward_ <= comm_threshold)) {
92       origin_cost->communication_redis_forward_ = comm_const;
93     } else if (origin_cost->communication_redis_forward_ > comm_threshold) {
94       origin_cost->communication_redis_forward_ += comm_bias;
95     }
96     if ((origin_cost->communication_redis_backward_ > EPS) &&
97         (origin_cost->communication_redis_backward_ <= comm_threshold)) {
98       origin_cost->communication_redis_backward_ = comm_const;
99     } else if (origin_cost->communication_redis_backward_ > comm_threshold) {
100       origin_cost->communication_redis_backward_ += comm_bias;
101     }
102     origin_cost->communication_cost_ =
103       origin_cost->communication_redis_forward_ + origin_cost->communication_redis_backward_;
104     origin_cost->communication_without_parameter_ = origin_cost->communication_cost_;
105     origin_cost->communication_with_partial_para_ = origin_cost->communication_cost_;
106   } else {
107     // Operator cost
108     double backward = 0.0;
109     if (std::abs(origin_cost->communication_cost_ - origin_cost->communication_without_parameter_) > EPS) {
110       backward = origin_cost->communication_cost_ - origin_cost->communication_without_parameter_;
111     }
112     // forward cost
113     if ((origin_cost->communication_without_parameter_ > EPS) &&
114         (origin_cost->communication_without_parameter_ <= comm_threshold)) {
115       origin_cost->communication_without_parameter_ = comm_const;
116     } else if (origin_cost->communication_without_parameter_ > comm_threshold) {
117       origin_cost->communication_without_parameter_ += comm_bias;
118     }
119     // total
120     if (origin_cost->communication_cost_ > EPS) {
121       origin_cost->communication_cost_ = origin_cost->communication_without_parameter_ + backward;
122     }
123     if (origin_cost->communication_with_partial_para_ > EPS) {
124       origin_cost->communication_with_partial_para_ = origin_cost->communication_without_parameter_ + gamma * backward;
125     }
126   }
127 }
128 }  // namespace parallel
129 }  // namespace mindspore
130