1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_H_
18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_H_
19
#include <cstdint>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "frontend/parallel/device_matrix.h"
27
28 namespace mindspore {
29 namespace parallel {
30 constexpr int MIN_SLICE_NUM = 1;
31
32 using Dimensions = Shape;
33 using NewDimensions = ShapeBasePtr;
34 using Strategies = std::vector<Dimensions>;
35 using NewStrategies = std::vector<NewDimensions>;
36 class Strategy;
37 using StrategyPtr = std::shared_ptr<Strategy>;
38
39 class Strategy {
40 public:
Strategy(int64_t stage,Strategies inputs)41 Strategy(int64_t stage, Strategies inputs)
42 : stage_(stage), inputs_(std::move(inputs)), internal_size_(0), internal_stragies_() {}
Strategy(int64_t stage,NewStrategies inputs)43 Strategy(int64_t stage, NewStrategies inputs)
44 : stage_(stage), inputs_new_(std::move(inputs)), internal_size_(0), internal_stragies_() {}
45
Strategy(const Strategy & another_stra)46 Strategy(const Strategy &another_stra) : stage_(another_stra.GetInputStage()) {
47 inputs_ = another_stra.GetInputDim();
48 inputs_new_ = another_stra.GetInputNewDim();
49 internal_size_ = another_stra.GetInternalSize();
50 if (internal_size_ != 0) {
51 internal_stragies_ = another_stra.GetInternalStrategies();
52 } else {
53 internal_stragies_ = {};
54 }
55 }
56
57 Strategy &operator=(const Strategy &another_stra) {
58 if (this != &another_stra) {
59 inputs_ = another_stra.GetInputDim();
60 inputs_new_ = another_stra.GetInputNewDim();
61 internal_size_ = another_stra.GetInternalSize();
62 if (internal_size_ != 0) {
63 internal_stragies_ = another_stra.GetInternalStrategies();
64 } else {
65 internal_stragies_ = {};
66 }
67 }
68 return *this;
69 }
70
71 ~Strategy() = default;
GetInputNumber()72 size_t GetInputNumber() const {
73 if (inputs_new_.empty()) {
74 return inputs_.size();
75 } else {
76 return inputs_new_.size();
77 }
78 }
HasTupleInTupleStrategy()79 bool HasTupleInTupleStrategy() const { return !inputs_new_.empty(); }
GetInputDim()80 Strategies GetInputDim() const { return inputs_; }
GetInputNewDim()81 NewStrategies GetInputNewDim() const { return inputs_new_; }
GetInputStage()82 int64_t GetInputStage() const { return stage_; }
ExpandInputDimFromOneToTwo()83 void ExpandInputDimFromOneToTwo() {
84 if (inputs_new_.empty()) {
85 if (inputs_.size() == 1) {
86 inputs_.push_back(inputs_[0]);
87 }
88 } else {
89 if (inputs_new_.size() == 1) {
90 inputs_new_.push_back(inputs_new_[0]);
91 }
92 }
93 }
ResetInputs(const Strategies & input)94 void ResetInputs(const Strategies &input) { inputs_ = input; }
ResetInputs(const NewStrategies & input)95 void ResetInputs(const NewStrategies &input) { inputs_new_ = input; }
GetInternalStrategies()96 std::vector<StrategyPtr> GetInternalStrategies() const { return internal_stragies_; }
GetInternalSize()97 size_t GetInternalSize() const { return internal_size_; }
98
IsEqual(const StrategyPtr & another_stra)99 bool IsEqual(const StrategyPtr &another_stra) {
100 if (another_stra == nullptr) {
101 return false;
102 }
103
104 std::vector<Dimensions> squashed_inputs_stra;
105 std::vector<size_t> stra_size;
106 std::vector<Dimensions> in_squashed_inputs_stra;
107 std::vector<size_t> in_stra_size;
108 // Current stra is tuple in tuple or not
109 std::tie(squashed_inputs_stra, stra_size) = GetSquashedStraAndSize();
110 // Input stra is tuple in tuple or not
111 std::tie(in_squashed_inputs_stra, in_stra_size) = GetInSquashedStraAndSize(another_stra);
112 if ((stage_ != another_stra->GetInputStage()) || (squashed_inputs_stra != in_squashed_inputs_stra) ||
113 (stra_size != in_stra_size)) {
114 return false;
115 }
116
117 return true;
118 }
119
PartitionNum()120 int64_t PartitionNum() {
121 int64_t divergence = 1;
122 if (inputs_new_.empty()) {
123 for (size_t i = 0; i < inputs_.size(); ++i) {
124 for (size_t j = 0; j < inputs_[i].size(); ++j) {
125 divergence *= inputs_[i][j];
126 }
127 }
128 } else {
129 for (const auto &stra : inputs_new_) {
130 ObtainPartionNum(stra, &divergence);
131 }
132 }
133 return divergence;
134 }
135
Compare(const StrategyPtr & another_stra)136 bool Compare(const StrategyPtr &another_stra) {
137 if (this->PartitionNum() > another_stra->PartitionNum()) {
138 return true;
139 } else if (this->PartitionNum() < another_stra->PartitionNum()) {
140 return false;
141 }
142 std::vector<Dimensions> squashed_inputs_stra;
143 std::vector<size_t> stra_size;
144 std::vector<Dimensions> in_squashed_inputs_stra;
145 std::vector<size_t> in_stra_size;
146 std::tie(squashed_inputs_stra, stra_size) = GetSquashedStraAndSize();
147 std::tie(in_squashed_inputs_stra, in_stra_size) = GetInSquashedStraAndSize(another_stra);
148 return squashed_inputs_stra > in_squashed_inputs_stra;
149 }
150
151 // Include 'another_stra' into this strategy
CoverStrategy(const StrategyPtr & another_stra)152 void CoverStrategy(const StrategyPtr &another_stra) {
153 internal_stragies_.push_back(another_stra);
154 internal_size_++;
155 }
156
ToString()157 std::string ToString() const {
158 std::ostringstream oss;
159 oss << "[";
160 if (this->HasTupleInTupleStrategy()) {
161 for (size_t i = 0; i < this->GetInputNumber(); ++i) {
162 CovertStrategyToString(this->GetInputNewDim()[i], &oss);
163 if (i != this->GetInputNumber() - 1) {
164 oss << ", ";
165 }
166 }
167 } else {
168 for (size_t i = 0; i < this->GetInputNumber(); ++i) {
169 oss << "[";
170 for (size_t j = 0; j < this->GetInputDim()[i].size(); ++j) {
171 oss << std::to_string(this->GetInputDim()[i][j]);
172 if (j != this->GetInputDim()[i].size() - 1) {
173 oss << ", ";
174 }
175 }
176 oss << "]";
177 if (i != this->GetInputNumber() - 1) {
178 oss << ", ";
179 }
180 }
181 }
182 oss << "]";
183 return oss.str();
184 }
185
186 private:
187 const int64_t stage_;
188
189 // The size of Dimensions must be equal to inputs_ tensor dimension.
190 Strategies inputs_;
191 NewStrategies inputs_new_;
192 size_t internal_size_ = 0;
193 std::vector<StrategyPtr> internal_stragies_;
194
ObtainPartionNum(const NewDimensions & inputs_stra,int64_t * divergence)195 void ObtainPartionNum(const NewDimensions &inputs_stra, int64_t *divergence) {
196 if (inputs_stra->is_list()) {
197 for (size_t i = 0; i < inputs_stra->size(); ++i) {
198 ObtainPartionNum(inputs_stra->GetElement(SizeToLong(i)), divergence);
199 }
200 } else {
201 auto stra = inputs_stra->GetValue();
202 for (const auto &stra_value : stra) {
203 *divergence *= stra_value;
204 }
205 }
206 }
207
GetInSquashedStraAndSize(const StrategyPtr & inputs_stra)208 std::pair<std::vector<Dimensions>, std::vector<size_t>> GetInSquashedStraAndSize(const StrategyPtr &inputs_stra) {
209 std::vector<Dimensions> in_squashed_inputs_stra;
210 std::vector<size_t> in_stra_size;
211 if (inputs_stra->HasTupleInTupleStrategy()) {
212 auto local_stra = inputs_stra->GetInputNewDim();
213 for (const auto &stra : local_stra) {
214 auto all_stra = stra->GetAllElements();
215 in_squashed_inputs_stra.insert(in_squashed_inputs_stra.end(), all_stra.begin(), all_stra.end());
216 in_stra_size.push_back(stra->size());
217 }
218 } else {
219 in_squashed_inputs_stra = inputs_stra->GetInputDim();
220 for (size_t i = 0; i < in_squashed_inputs_stra.size(); ++i) {
221 in_stra_size.push_back(in_squashed_inputs_stra[i].size());
222 }
223 }
224 return std::make_pair(in_squashed_inputs_stra, in_stra_size);
225 }
226
GetSquashedStraAndSize()227 std::pair<std::vector<Dimensions>, std::vector<size_t>> GetSquashedStraAndSize() {
228 std::vector<Dimensions> squashed_inputs_stra;
229 std::vector<size_t> stra_size;
230 if (inputs_new_.empty()) {
231 squashed_inputs_stra = inputs_;
232 for (size_t i = 0; i < inputs_.size(); ++i) {
233 stra_size.push_back(inputs_[i].size());
234 }
235 } else {
236 for (const auto &stra : inputs_new_) {
237 auto all_stra = stra->GetAllElements();
238 squashed_inputs_stra.insert(squashed_inputs_stra.end(), all_stra.begin(), all_stra.end());
239 stra_size.push_back(stra->size());
240 }
241 }
242 return std::make_pair(squashed_inputs_stra, stra_size);
243 }
244
CovertStrategyToString(const NewDimensions & stra,std::ostringstream * oss)245 void CovertStrategyToString(const NewDimensions &stra, std::ostringstream *oss) const {
246 *oss << "[";
247 if (stra->is_list()) {
248 for (size_t i = 0; i < stra->size(); ++i) {
249 CovertStrategyToString(stra->GetElement(SizeToLong(i)), oss);
250 if (i != stra->size() - 1) {
251 *oss << ", ";
252 }
253 }
254 } else {
255 auto stra_value = stra->GetValue();
256 for (size_t i = 0; i < stra_value.size(); ++i) {
257 *oss << stra_value[i];
258 if (i != stra_value.size() - 1) {
259 *oss << ", ";
260 }
261 }
262 }
263 *oss << "]";
264 }
265 };
266
NewStrategy(const int64_t stage,const Strategies & inputs)267 inline StrategyPtr NewStrategy(const int64_t stage, const Strategies &inputs) {
268 return std::make_shared<Strategy>(stage, inputs);
269 }
NewStrategy(const int64_t stage,const NewStrategies & inputs)270 inline StrategyPtr NewStrategy(const int64_t stage, const NewStrategies &inputs) {
271 return std::make_shared<Strategy>(stage, inputs);
272 }
273 } // namespace parallel
274 } // namespace mindspore
275
276 #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_STRATEGY_H_
277