• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_H_
18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_H_
19 
#include <cstdint>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "frontend/parallel/device_matrix.h"
27 
28 namespace mindspore {
29 namespace parallel {
30 constexpr int MIN_SLICE_NUM = 1;
31 
32 using Dimensions = Shape;
33 using NewDimensions = ShapeBasePtr;
34 using Strategies = std::vector<Dimensions>;
35 using NewStrategies = std::vector<NewDimensions>;
36 class Strategy;
37 using StrategyPtr = std::shared_ptr<Strategy>;
38 
39 class Strategy {
40  public:
Strategy(int64_t stage,Strategies inputs)41   Strategy(int64_t stage, Strategies inputs)
42       : stage_(stage), inputs_(std::move(inputs)), internal_size_(0), internal_stragies_() {}
Strategy(int64_t stage,NewStrategies inputs)43   Strategy(int64_t stage, NewStrategies inputs)
44       : stage_(stage), inputs_new_(std::move(inputs)), internal_size_(0), internal_stragies_() {}
45 
Strategy(const Strategy & another_stra)46   Strategy(const Strategy &another_stra) : stage_(another_stra.GetInputStage()) {
47     inputs_ = another_stra.GetInputDim();
48     inputs_new_ = another_stra.GetInputNewDim();
49     internal_size_ = another_stra.GetInternalSize();
50     if (internal_size_ != 0) {
51       internal_stragies_ = another_stra.GetInternalStrategies();
52     } else {
53       internal_stragies_ = {};
54     }
55   }
56 
57   Strategy &operator=(const Strategy &another_stra) {
58     if (this != &another_stra) {
59       inputs_ = another_stra.GetInputDim();
60       inputs_new_ = another_stra.GetInputNewDim();
61       internal_size_ = another_stra.GetInternalSize();
62       if (internal_size_ != 0) {
63         internal_stragies_ = another_stra.GetInternalStrategies();
64       } else {
65         internal_stragies_ = {};
66       }
67     }
68     return *this;
69   }
70 
71   ~Strategy() = default;
GetInputNumber()72   size_t GetInputNumber() const {
73     if (inputs_new_.empty()) {
74       return inputs_.size();
75     } else {
76       return inputs_new_.size();
77     }
78   }
HasTupleInTupleStrategy()79   bool HasTupleInTupleStrategy() const { return !inputs_new_.empty(); }
GetInputDim()80   Strategies GetInputDim() const { return inputs_; }
GetInputNewDim()81   NewStrategies GetInputNewDim() const { return inputs_new_; }
GetInputStage()82   int64_t GetInputStage() const { return stage_; }
ExpandInputDimFromOneToTwo()83   void ExpandInputDimFromOneToTwo() {
84     if (inputs_new_.empty()) {
85       if (inputs_.size() == 1) {
86         inputs_.push_back(inputs_[0]);
87       }
88     } else {
89       if (inputs_new_.size() == 1) {
90         inputs_new_.push_back(inputs_new_[0]);
91       }
92     }
93   }
ResetInputs(const Strategies & input)94   void ResetInputs(const Strategies &input) { inputs_ = input; }
ResetInputs(const NewStrategies & input)95   void ResetInputs(const NewStrategies &input) { inputs_new_ = input; }
GetInternalStrategies()96   std::vector<StrategyPtr> GetInternalStrategies() const { return internal_stragies_; }
GetInternalSize()97   size_t GetInternalSize() const { return internal_size_; }
98 
IsEqual(const StrategyPtr & another_stra)99   bool IsEqual(const StrategyPtr &another_stra) {
100     if (another_stra == nullptr) {
101       return false;
102     }
103 
104     std::vector<Dimensions> squashed_inputs_stra;
105     std::vector<size_t> stra_size;
106     std::vector<Dimensions> in_squashed_inputs_stra;
107     std::vector<size_t> in_stra_size;
108     // Current stra is tuple in tuple or not
109     std::tie(squashed_inputs_stra, stra_size) = GetSquashedStraAndSize();
110     // Input stra is tuple in tuple or not
111     std::tie(in_squashed_inputs_stra, in_stra_size) = GetInSquashedStraAndSize(another_stra);
112     if ((stage_ != another_stra->GetInputStage()) || (squashed_inputs_stra != in_squashed_inputs_stra) ||
113         (stra_size != in_stra_size)) {
114       return false;
115     }
116 
117     return true;
118   }
119 
PartitionNum()120   int64_t PartitionNum() {
121     int64_t divergence = 1;
122     if (inputs_new_.empty()) {
123       for (size_t i = 0; i < inputs_.size(); ++i) {
124         for (size_t j = 0; j < inputs_[i].size(); ++j) {
125           divergence *= inputs_[i][j];
126         }
127       }
128     } else {
129       for (const auto &stra : inputs_new_) {
130         ObtainPartionNum(stra, &divergence);
131       }
132     }
133     return divergence;
134   }
135 
Compare(const StrategyPtr & another_stra)136   bool Compare(const StrategyPtr &another_stra) {
137     if (this->PartitionNum() > another_stra->PartitionNum()) {
138       return true;
139     } else if (this->PartitionNum() < another_stra->PartitionNum()) {
140       return false;
141     }
142     std::vector<Dimensions> squashed_inputs_stra;
143     std::vector<size_t> stra_size;
144     std::vector<Dimensions> in_squashed_inputs_stra;
145     std::vector<size_t> in_stra_size;
146     std::tie(squashed_inputs_stra, stra_size) = GetSquashedStraAndSize();
147     std::tie(in_squashed_inputs_stra, in_stra_size) = GetInSquashedStraAndSize(another_stra);
148     return squashed_inputs_stra > in_squashed_inputs_stra;
149   }
150 
151   // Include 'another_stra' into this strategy
CoverStrategy(const StrategyPtr & another_stra)152   void CoverStrategy(const StrategyPtr &another_stra) {
153     internal_stragies_.push_back(another_stra);
154     internal_size_++;
155   }
156 
ToString()157   std::string ToString() const {
158     std::ostringstream oss;
159     oss << "[";
160     if (this->HasTupleInTupleStrategy()) {
161       for (size_t i = 0; i < this->GetInputNumber(); ++i) {
162         CovertStrategyToString(this->GetInputNewDim()[i], &oss);
163         if (i != this->GetInputNumber() - 1) {
164           oss << ", ";
165         }
166       }
167     } else {
168       for (size_t i = 0; i < this->GetInputNumber(); ++i) {
169         oss << "[";
170         for (size_t j = 0; j < this->GetInputDim()[i].size(); ++j) {
171           oss << std::to_string(this->GetInputDim()[i][j]);
172           if (j != this->GetInputDim()[i].size() - 1) {
173             oss << ", ";
174           }
175         }
176         oss << "]";
177         if (i != this->GetInputNumber() - 1) {
178           oss << ", ";
179         }
180       }
181     }
182     oss << "]";
183     return oss.str();
184   }
185 
186  private:
187   const int64_t stage_;
188 
189   // The size of Dimensions must be equal to inputs_ tensor dimension.
190   Strategies inputs_;
191   NewStrategies inputs_new_;
192   size_t internal_size_ = 0;
193   std::vector<StrategyPtr> internal_stragies_;
194 
ObtainPartionNum(const NewDimensions & inputs_stra,int64_t * divergence)195   void ObtainPartionNum(const NewDimensions &inputs_stra, int64_t *divergence) {
196     if (inputs_stra->is_list()) {
197       for (size_t i = 0; i < inputs_stra->size(); ++i) {
198         ObtainPartionNum(inputs_stra->GetElement(SizeToLong(i)), divergence);
199       }
200     } else {
201       auto stra = inputs_stra->GetValue();
202       for (const auto &stra_value : stra) {
203         *divergence *= stra_value;
204       }
205     }
206   }
207 
GetInSquashedStraAndSize(const StrategyPtr & inputs_stra)208   std::pair<std::vector<Dimensions>, std::vector<size_t>> GetInSquashedStraAndSize(const StrategyPtr &inputs_stra) {
209     std::vector<Dimensions> in_squashed_inputs_stra;
210     std::vector<size_t> in_stra_size;
211     if (inputs_stra->HasTupleInTupleStrategy()) {
212       auto local_stra = inputs_stra->GetInputNewDim();
213       for (const auto &stra : local_stra) {
214         auto all_stra = stra->GetAllElements();
215         in_squashed_inputs_stra.insert(in_squashed_inputs_stra.end(), all_stra.begin(), all_stra.end());
216         in_stra_size.push_back(stra->size());
217       }
218     } else {
219       in_squashed_inputs_stra = inputs_stra->GetInputDim();
220       for (size_t i = 0; i < in_squashed_inputs_stra.size(); ++i) {
221         in_stra_size.push_back(in_squashed_inputs_stra[i].size());
222       }
223     }
224     return std::make_pair(in_squashed_inputs_stra, in_stra_size);
225   }
226 
GetSquashedStraAndSize()227   std::pair<std::vector<Dimensions>, std::vector<size_t>> GetSquashedStraAndSize() {
228     std::vector<Dimensions> squashed_inputs_stra;
229     std::vector<size_t> stra_size;
230     if (inputs_new_.empty()) {
231       squashed_inputs_stra = inputs_;
232       for (size_t i = 0; i < inputs_.size(); ++i) {
233         stra_size.push_back(inputs_[i].size());
234       }
235     } else {
236       for (const auto &stra : inputs_new_) {
237         auto all_stra = stra->GetAllElements();
238         squashed_inputs_stra.insert(squashed_inputs_stra.end(), all_stra.begin(), all_stra.end());
239         stra_size.push_back(stra->size());
240       }
241     }
242     return std::make_pair(squashed_inputs_stra, stra_size);
243   }
244 
CovertStrategyToString(const NewDimensions & stra,std::ostringstream * oss)245   void CovertStrategyToString(const NewDimensions &stra, std::ostringstream *oss) const {
246     *oss << "[";
247     if (stra->is_list()) {
248       for (size_t i = 0; i < stra->size(); ++i) {
249         CovertStrategyToString(stra->GetElement(SizeToLong(i)), oss);
250         if (i != stra->size() - 1) {
251           *oss << ", ";
252         }
253       }
254     } else {
255       auto stra_value = stra->GetValue();
256       for (size_t i = 0; i < stra_value.size(); ++i) {
257         *oss << stra_value[i];
258         if (i != stra_value.size() - 1) {
259           *oss << ", ";
260         }
261       }
262     }
263     *oss << "]";
264   }
265 };
266 
NewStrategy(const int64_t stage,const Strategies & inputs)267 inline StrategyPtr NewStrategy(const int64_t stage, const Strategies &inputs) {
268   return std::make_shared<Strategy>(stage, inputs);
269 }
NewStrategy(const int64_t stage,const NewStrategies & inputs)270 inline StrategyPtr NewStrategy(const int64_t stage, const NewStrategies &inputs) {
271   return std::make_shared<Strategy>(stage, inputs);
272 }
273 }  // namespace parallel
274 }  // namespace mindspore
275 
276 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_H_
277