1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/tensor_layout/tensor_layout.h"
18 #include <iostream>
19 #include <utility>
20 #include "utils/ms_utils.h"
21 #include "ir/value.h"
22 #include "frontend/parallel/device_matrix.h"
23 #include "frontend/parallel/status.h"
24 #include "frontend/parallel/context.h"
25 #include "frontend/parallel/tensor_layout/shape_util.h"
26 #include "utils/log_adapter.h"
27
28 namespace mindspore {
29 namespace parallel {
ToString() const30 std::string TensorLayout::ToString() const { return StandardToString() + OriginToString(); }
31
StandardToString() const32 std::string TensorLayout::StandardToString() const {
33 std::ostringstream buffer;
34 buffer << std::endl << std::string("device arrangement = " + device_arrangement_.ToString());
35 buffer << std::endl << std::string("tensor map = " + tensor_map_.ToString());
36 buffer << std::endl << std::string("tensor shape = " + tensor_shape_.ToString());
37 return buffer.str();
38 }
39
OriginToString() const40 std::string TensorLayout::OriginToString() const {
41 std::ostringstream buffer;
42 buffer << std::endl << std::string("device arrangement origin = " + device_arrangement_origin_.ToString());
43 buffer << std::endl << std::string("tensor map origin = " + tensor_map_origin_.ToString());
44 buffer << std::endl << std::string("tensor shape origin = " + tensor_shape_origin_.ToString());
45 return buffer.str();
46 }
47
Init(const Arrangement & device_arrangement,const Map & tensor_map,const Arrangement & tensor_shape)48 Status TensorLayout::Init(const Arrangement &device_arrangement, const Map &tensor_map,
49 const Arrangement &tensor_shape) {
50 device_arrangement_origin_ = device_arrangement;
51 tensor_map_origin_ = tensor_map;
52 tensor_shape_origin_ = tensor_shape;
53 device_arrangement_ = device_arrangement;
54 tensor_map_ = tensor_map;
55 tensor_shape_ = tensor_shape;
56 if (IsValidTensorLayout()) {
57 MS_LOG(DEBUG) << "valid origin tensor layout " << this->OriginToString();
58 RemoveElementEqualToOneInDeviceArrangement();
59 MS_LOG(DEBUG) << "standard tensor layout " << this->StandardToString();
60 return Status::SUCCESS;
61 } else {
62 if (layout_transfer_) {
63 MS_LOG(DEBUG) << "invalid origin tensor layout " << this->OriginToString();
64 } else {
65 MS_LOG(ERROR) << "invalid origin tensor layout " << this->OriginToString();
66 }
67 return Status::FAILED;
68 }
69 }
70
InitFromVector(const Shape & device_arrangement,const Shape & tensor_map,const Shape & tensor_shape)71 Status TensorLayout::InitFromVector(const Shape &device_arrangement, const Shape &tensor_map,
72 const Shape &tensor_shape) {
73 if (device_arrangement_origin_.Init(device_arrangement) != SUCCESS) {
74 return FAILED;
75 }
76 if (tensor_map_origin_.Init(tensor_map) != SUCCESS) {
77 return FAILED;
78 }
79 if (tensor_shape_origin_.Init(tensor_shape) != SUCCESS) {
80 return FAILED;
81 }
82 if (Init(device_arrangement_origin_, tensor_map_origin_, tensor_shape_origin_) != SUCCESS) {
83 return FAILED;
84 }
85 return SUCCESS;
86 }
87
IsValidTensorLayout() const88 bool TensorLayout::IsValidTensorLayout() const {
89 if (tensor_map_origin_.GetMaxItem() >= static_cast<int64_t>(device_arrangement_origin_.GetDimSize())) {
90 MS_LOG(ERROR) << "the max element in tensor_map_origin_ must be smaller than device_arrangement_origin_ size!";
91 return false;
92 }
93 if (tensor_map_origin_.GetDimSize() != tensor_shape_origin_.GetDimSize()) {
94 MS_LOG(ERROR) << "tensor_map_origin_ size must be equal to tensor_shape_origin_ size!";
95 return false;
96 }
97 if (!TensorShapeDimensionIsDividedBySplitDeviceDimension()) {
98 if (layout_transfer_) {
99 MS_LOG(DEBUG) << "TensorShapeDimensionIsDividedBySplitDeviceDimension failed!";
100 } else {
101 MS_LOG(ERROR) << "TensorShapeDimensionIsDividedBySplitDeviceDimension failed!";
102 }
103 return false;
104 }
105 return true;
106 }
107
TensorShapeDimensionIsDividedBySplitDeviceDimension() const108 bool TensorLayout::TensorShapeDimensionIsDividedBySplitDeviceDimension() const {
109 for (uint64_t i = 0; i < tensor_map_.GetDimSize(); i++) {
110 if (tensor_map_.GetDimByIdx(i) != -1) {
111 int64_t divisor = GetSliceNumByTensorDimensionIndex(i);
112 if (divisor == 0) {
113 MS_LOG(ERROR) << "GetSliceNumByTensorDimensionIndex is 0";
114 return false;
115 }
116 if (tensor_shape_.GetDimByIdx(i) % divisor != 0) {
117 return false;
118 }
119 }
120 }
121 return true;
122 }
123
// Standardize the layout by dropping every device dimension whose size is 1:
// such dimensions contribute no real sharding. The tensor map is rebuilt so that
// (a) a tensor dimension mapped onto a removed device dimension becomes unmapped
// (-1) and (b) remaining map values are shifted down to stay aligned with the
// compacted device arrangement. tensor_shape_ is copied from the origin unchanged.
void TensorLayout::RemoveElementEqualToOneInDeviceArrangement() {
  Shape device_arrangement_shape;
  Shape tensor_map_shape = tensor_map_origin_.array();
  size_t dev_num = device_arrangement_origin_.GetDimSize();
  // NOTE(review): dev_num_left is never decremented below, so it always equals dev_num here.
  size_t dev_num_left = device_arrangement_origin_.GetDimSize();
  for (size_t i = 0; i < dev_num; i++) {
    if (device_arrangement_origin_.GetDimByIdx(i) == 1) {
      // Map values index device dimensions from the right, so the dimension at
      // left-based position i corresponds to map value (dev_num - 1 - i).
      int64_t idx = GetTensorDimensionIndexByDeviceDimensionIndex(static_cast<int64_t>(dev_num - 1 - i));
      if (idx != -1) {
        // Unmap the tensor dimension that pointed at the removed device dimension.
        tensor_map_shape[static_cast<uint64_t>(idx)] = -1;
      }
      // Map values referring to device dimensions left of the removed one (i.e.
      // values greater than the removed value) shift down by one.
      for (auto &value : tensor_map_shape) {
        if (value >= SizeToLong(dev_num_left) - 1 - static_cast<int64_t>(i)) {
          value--;
        }
      }
      continue;
    }
    // Keep device dimensions with size > 1.
    device_arrangement_shape.push_back(device_arrangement_origin_.GetDimByIdx(i));
  }
  (void)device_arrangement_.Init(device_arrangement_shape);
  (void)tensor_map_.Init(tensor_map_shape);
  tensor_shape_ = tensor_shape_origin_;
}
148
149 // if idx is not in tensor_map, return -1
GetTensorDimensionIndexByDeviceDimensionIndex(int64_t idx) const150 int64_t TensorLayout::GetTensorDimensionIndexByDeviceDimensionIndex(int64_t idx) const {
151 return tensor_map_.GetIndexByValue(idx);
152 }
153
154 // tensor_map_.GetDimByIdx(idx) should not be -1
GetSliceDeviceDimensionByTensorDimensionIndex(uint64_t idx) const155 int64_t TensorLayout::GetSliceDeviceDimensionByTensorDimensionIndex(uint64_t idx) const {
156 return static_cast<int64_t>(device_arrangement_.GetDimSize()) - 1 - tensor_map_.GetDimByIdx(idx);
157 }
158
159 // tensor_map_.GetDimByIdx(idx) should not be -1
GetSliceNumByTensorDimensionIndex(uint64_t idx) const160 int64_t TensorLayout::GetSliceNumByTensorDimensionIndex(uint64_t idx) const {
161 return device_arrangement_.GetDimByIdx(static_cast<uint64_t>(GetSliceDeviceDimensionByTensorDimensionIndex(idx)));
162 }
163
ExpandTensorShape(const Arrangement & expanded_shape) const164 std::shared_ptr<TensorLayout> TensorLayout::ExpandTensorShape(const Arrangement &expanded_shape) const {
165 std::shared_ptr<Arrangement> expanded_arrangement_ptr = ComputeArrangementByExpandedShape(expanded_shape);
166 if (expanded_arrangement_ptr == nullptr) {
167 return nullptr;
168 }
169 std::shared_ptr<TensorLayout> temp_tensor_layout_ptr = ExpandDeviceArrangement(*expanded_arrangement_ptr);
170 if (temp_tensor_layout_ptr == nullptr) {
171 return nullptr;
172 }
173 return temp_tensor_layout_ptr->ExpandTensorShapeWithoutExtendDeviceArrangement(expanded_shape);
174 }
175
176 /*
177 * example1:
178 * in_device_arrangement = [8, 4],
179 * in_tensor_map = [1, 0],
180 * in_tensor_shape = [512, 1024],
181 * out_tensor_shape = [128, 4, 2, 512],
182 * =>
183 * out_device_arrangement = [8, 2, 2]
184 */
ComputeArrangementByExpandedShape(const Arrangement & tensor_shape) const185 std::shared_ptr<Arrangement> TensorLayout::ComputeArrangementByExpandedShape(const Arrangement &tensor_shape) const {
186 std::shared_ptr<std::vector<Arrangement>> expand_list_ptr = tensor_shape_.GetExpandShapeList(tensor_shape);
187 if (expand_list_ptr == nullptr) {
188 return nullptr;
189 }
190 std::vector<Arrangement> re_map_expand_list;
191 Arrangement empty_arrangement;
192 for (int64_t i = static_cast<int64_t>(device_arrangement_.GetDimSize()) - 1; i >= 0; i--) {
193 if (tensor_map_.GetIndexByValue(i) < 0) {
194 re_map_expand_list.push_back(empty_arrangement);
195 } else {
196 re_map_expand_list.push_back((*expand_list_ptr)[LongToUlong(tensor_map_.GetIndexByValue(i))]);
197 }
198 }
199 std::shared_ptr<Arrangement> new_arrangement_ptr =
200 device_arrangement_.GetExpandedShapeByExpandListRemoveLeft(re_map_expand_list);
201 return new_arrangement_ptr;
202 }
203
204 /*
205 * example1:
206 * in_device_arrangement = [8, 4],
207 * in_tensor_map = [1, 0],
208 * in_tensor_shape = [512, 1024],
209 * out_tensor_shape = [8, 64, 4, 256]
210 * =>
211 * out_device_arrangement = [8, 4],
212 * out_tensor_map = [1, -1, 0, -1],
213 */
ExpandTensorShapeWithoutExtendDeviceArrangement(const Arrangement & expanded_shape) const214 std::shared_ptr<TensorLayout> TensorLayout::ExpandTensorShapeWithoutExtendDeviceArrangement(
215 const Arrangement &expanded_shape) const {
216 std::shared_ptr<std::pair<std::vector<Arrangement>, Arrangement>> expand_list_pair_ptr =
217 tensor_shape_.GetExpandShapeListPair(expanded_shape);
218 if (expand_list_pair_ptr == nullptr) {
219 return nullptr;
220 }
221 std::shared_ptr<Map> tensor_map_new_ptr = tensor_map_.ExpandMapByNone(expand_list_pair_ptr->second);
222 if (tensor_map_new_ptr == nullptr) {
223 return nullptr;
224 }
225 TensorLayout tensor_layout_new;
226 tensor_layout_new.set_layout_transfer(true);
227 Status status = tensor_layout_new.Init(device_arrangement_, *tensor_map_new_ptr, expanded_shape);
228 if (status != Status::SUCCESS) {
229 return nullptr;
230 }
231 return std::make_shared<TensorLayout>(tensor_layout_new);
232 }
233
234 /*
235 * example1:
236 * in_device_arrangement = [8, 4],
237 * in_tensor_map = [1, 0],
238 * in_tensor_shape = [512, 1024],
239 * out_device_arrangement = [4, 2, 2, 2]
240 * =>
241 * out_tensor_map = [3, 2, 1, 0],
242 * out_tensor_shape = [4, 128, 2, 512]
243 *
244 * example2:
245 * in_device_arrangement = [8, 4],
246 * in_tensor_map = [0, 1],
247 * in_tensor_shape = [512, 1024],
248 * out_device_arrangement = [4, 2, 2, 2]
249 * =>
250 * out_tensor_map = [1, 0, 3, 2],
251 * out_tensor_shape = [2, 256, 4, 256]
252 *
253 * example3:
254 * in_device_arrangement = [8, 4],
255 * in_tensor_map = [1, -1],
256 * in_tensor_shape = [512, 1024],
257 * out_device_arrangement = [4, 2, 2, 2]
258 * =>
259 * out_tensor_map = [3, 2, -1],
260 * out_tensor_shape = [4, 128, 1024]
261 *
262 * example4:
263 * in_device_arrangement = [8, 4],
264 * in_tensor_map = [0, 1],
265 * in_tensor_shape = [512, 1024],
266 * out_device_arrangement = [4, 2, 4]
267 * =>
268 * out_tensor_map = [0, 2, 1],
269 * out_tensor_shape = [512, 4, 256]
270 */
ExpandDeviceArrangement(const Arrangement & expanded_arrangement) const271 std::shared_ptr<TensorLayout> TensorLayout::ExpandDeviceArrangement(const Arrangement &expanded_arrangement) const {
272 std::shared_ptr<std::pair<std::vector<Arrangement>, Arrangement>> expand_list_pair_ptr =
273 device_arrangement_.GetExpandShapeListPair(expanded_arrangement);
274 if (expand_list_pair_ptr == nullptr) {
275 return nullptr;
276 }
277 std::shared_ptr<Map> tensor_map_new_ptr = tensor_map_.ExpandMapByDecreaseNumber(expand_list_pair_ptr->second);
278 if (tensor_map_new_ptr == nullptr) {
279 return nullptr;
280 }
281 std::shared_ptr<std::vector<Arrangement>> re_map_shape_list_ptr =
282 tensor_map_.ReMapVector(expand_list_pair_ptr->first);
283 if (re_map_shape_list_ptr == nullptr) {
284 return nullptr;
285 }
286 std::shared_ptr<Arrangement> tensor_shape_new_ptr =
287 tensor_shape_.GetExpandedShapeByExpandListReserveLeft(*re_map_shape_list_ptr);
288 if (tensor_shape_new_ptr == nullptr) {
289 return nullptr;
290 }
291 TensorLayout tensor_layout_new;
292 Status status = tensor_layout_new.Init(expanded_arrangement, *tensor_map_new_ptr, *tensor_shape_new_ptr);
293 if (status != Status::SUCCESS) {
294 return nullptr;
295 }
296 return std::make_shared<TensorLayout>(tensor_layout_new);
297 }
298
TensorShapeCanBeExpanded(const Arrangement & expand_shape) const299 bool TensorLayout::TensorShapeCanBeExpanded(const Arrangement &expand_shape) const {
300 Shape in_expand_shape_shape;
301 Status status = ExpandShape(tensor_shape_.array(), expand_shape.array(), &in_expand_shape_shape);
302 if (status != Status::SUCCESS) {
303 return false;
304 }
305 return (in_expand_shape_shape == tensor_shape_.array());
306 }
307
ComputeExpandedTensorShape(const Arrangement & expand_shape) const308 std::shared_ptr<Arrangement> TensorLayout::ComputeExpandedTensorShape(const Arrangement &expand_shape) const {
309 Shape in_expand_shape_shape;
310 Status status = ExpandShape(tensor_shape_.array(), expand_shape.array(), &in_expand_shape_shape);
311 if (status != Status::SUCCESS) {
312 return nullptr;
313 }
314 Arrangement expanded_shape;
315 status = expanded_shape.Init(in_expand_shape_shape);
316 if (status != Status::SUCCESS) {
317 return nullptr;
318 }
319 return std::make_shared<Arrangement>(expanded_shape);
320 }
321
slice_shape() const322 Arrangement TensorLayout::slice_shape() const {
323 Shape shape;
324 for (size_t index = 0; index < tensor_map_.GetDimSize(); index++) {
325 int64_t dim = tensor_map_.GetDimByIdx(index);
326 int64_t num = tensor_shape_.GetDimByIdx(index);
327 if (dim == -1) {
328 shape.push_back(num);
329 } else {
330 int64_t divisor = device_arrangement_.GetDimByReverseIdx(LongToUlong(dim));
331 shape.push_back(num / divisor);
332 }
333 }
334 Arrangement new_tensor_shape;
335 if (new_tensor_shape.Init(shape) == Status::FAILED) {
336 ValuePtr ptr = MakeValue(shape);
337 MS_LOG(EXCEPTION) << "Can't get slice shape when initialize a new shape " << ptr->ToString();
338 } else {
339 return new_tensor_shape;
340 }
341 }
342
UpdateTensorMap(size_t index,int64_t value)343 Status TensorLayout::UpdateTensorMap(size_t index, int64_t value) {
344 if (index >= tensor_map_.GetDimSize()) {
345 MS_LOG(ERROR) << "Index is out of the size of the tensor map!";
346 return Status::FAILED;
347 }
348 auto shape = tensor_map_.array();
349 shape[index] = value;
350 if (tensor_map_.Init(shape) == Status::FAILED) {
351 MS_LOG(ERROR) << "Update tensor map failed!";
352 return Status::FAILED;
353 }
354 return Status::SUCCESS;
355 }
356
operator ==(const TensorLayout & t1) const357 bool TensorLayout::operator==(const TensorLayout &t1) const {
358 return (IsSameDeviceArrangement(t1) && IsSameTensorMap(t1) && IsSameTensorShape(t1));
359 }
360
operator !=(const TensorLayout & t1) const361 bool TensorLayout::operator!=(const TensorLayout &t1) const {
362 return !(IsSameDeviceArrangement(t1) && IsSameTensorMap(t1) && IsSameTensorShape(t1));
363 }
364
365 /*
366 * remove elements equal to 1 in tensor_shape, if all elements are 1, squeeze the tensor_shape to [ 1 ]
367 * example 1:
368 * original tensor layout:
369 * device arrangement = [ 8 ]
370 * tensor map = [ 0 -1 -1 -1 ]
371 * tensor shape = [ 128 64 1 1 ]
372 * return tensor layout:
373 * device arrangement = [ 8 ]
374 * tensor map = [ 0 -1 ]
375 * tensor shape = [ 128 64 ]
376 *
377 * example 2:
378 * original tensor layout:
379 * device arrangement = [ 8 ]
380 * tensor map = [ -1 -1 -1 -1 ]
381 * tensor shape = [ 1 1 1 1 ]
382 * return tensor layout:
383 * device arrangement = [ 8 ]
384 * tensor map = [ -1 ]
385 * tensor shape = [ 1 ]
386 */
SqueezeShape() const387 TensorLayout TensorLayout::SqueezeShape() const {
388 TensorLayout out;
389 Map out_map;
390 Arrangement out_shape;
391 if (tensor_shape_.size() == 1) {
392 (void)out_map.Init({MAP_NONE});
393 (void)out_shape.Init({1});
394 (void)out.Init(device_arrangement_, out_map, out_shape);
395 return out;
396 }
397 std::vector<size_t> squeeze_list = tensor_shape_.GetSqueezeIdx();
398 if (!tensor_map_.CheckNoneByIdxList(squeeze_list)) {
399 MS_LOG(ERROR) << "CheckNoneByIdxList failed, this may not happen under current situation";
400 return *this;
401 }
402 out_shape = tensor_shape_.GetSqueezeArrangement();
403 out_map = tensor_map_.SqueezeMapByIdxList(squeeze_list);
404 (void)out.Init(device_arrangement_, out_map, out_shape);
405 return out;
406 }
407
TransferRepeatLayout() const408 TensorLayout TensorLayout::TransferRepeatLayout() const {
409 Shape dev_mat(device_arrangement_origin_.array());
410 Shape tensor_map(tensor_map_origin_.GetDimSize(), -1);
411 Shape tensor_shape(tensor_shape_origin_.array());
412 TensorLayout repeat;
413 repeat.InitFromVector(dev_mat, tensor_map, tensor_shape);
414 return repeat;
415 }
416
417 // Generate a totally shard tensor slice shape for parallel optimizer
GenerateOptShardSliceShape()418 Status TensorLayout::GenerateOptShardSliceShape() {
419 MS_LOG(INFO) << "layout for GetOptShardSliceShape is " << StandardToString();
420 Shape dev_max = device_arrangement_.array();
421 Shape tensor_map = tensor_map_.array();
422 Shape repeated_dev;
423 for (size_t i = 0; i < dev_max.size(); i++) {
424 if (tensor_map_.GetIndexByValue(static_cast<int64_t>(i)) == MAP_NONE) {
425 repeated_dev.push_back(dev_max[dev_max.size() - 1 - i]);
426 dev_max[dev_max.size() - 1 - i] = 1;
427 }
428 }
429 if (repeated_dev.empty()) {
430 MS_LOG(INFO) << "Tensor is totally shard already.";
431 return Status::FAILED;
432 }
433 int64_t repeated_num =
434 std::accumulate(repeated_dev.begin(), repeated_dev.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
435 int64_t split_num;
436 int64_t optimizer_weight_shard_size = ParallelContext::GetInstance()->optimizer_weight_shard_size();
437 if (optimizer_weight_shard_size != -1) {
438 repeated_num = optimizer_weight_shard_size;
439 }
440 if (tensor_map[0] == MAP_NONE) {
441 split_num = repeated_num;
442 } else {
443 split_num = dev_max[dev_max.size() - 1 - static_cast<size_t>(tensor_map[0])] * repeated_num;
444 }
445 if (tensor_shape_.array()[0] % split_num != 0) {
446 MS_LOG(INFO) << "Tensor could not be shard on the first dimension.";
447 return Status::FAILED;
448 }
449 Shape origin_slice_shape = slice_shape().array();
450 origin_slice_shape[0] = tensor_shape_.array()[0] / split_num;
451 opt_shard_slice_shape_ = origin_slice_shape;
452 return Status::SUCCESS;
453 }
454 } // namespace parallel
455 } // namespace mindspore
456