1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_TENSOR_LAYOUT_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_TENSOR_LAYOUT_H_ 19 20 #include <cstdint> 21 #include <map> 22 #include <memory> 23 #include <string> 24 #include <utility> 25 #include <vector> 26 #include <functional> 27 #include "frontend/parallel/device_manager.h" 28 #include "frontend/parallel/status.h" 29 #include "frontend/parallel/tensor_layout/arrangement.h" 30 #include "frontend/parallel/tensor_layout/map.h" 31 #include "utils/convert_utils.h" 32 33 namespace mindspore { 34 namespace parallel { 35 class TensorLayout { 36 public: 37 TensorLayout() = default; 38 ~TensorLayout() = default; 39 std::string ToString() const; 40 std::string StandardToString() const; 41 std::string OriginToString() const; 42 Status Init(const Arrangement &device_arrangement, const Map &tensor_map, const Arrangement &tensor_shape); 43 Status InitFromVector(const Shape &device_arrangement, const Shape &tensor_map, const Shape &tensor_shape); 44 skip_redistribution()45 bool skip_redistribution() const { return skip_redistribution_; } 46 set_skip_redistribution(bool flag)47 void set_skip_redistribution(bool flag) { skip_redistribution_ = flag; } 48 layout_transfer()49 bool layout_transfer() const { return layout_transfer_; } 50 set_layout_transfer(bool flag)51 void set_layout_transfer(bool flag) { layout_transfer_ = flag; } 52 get_field_size()53 int64_t get_field_size() const { return field_size_; } 54 set_field_size(int64_t field_size)55 void set_field_size(int64_t field_size) { field_size_ = field_size; } 56 uniform_split()57 bool uniform_split() const { return uniform_split_; } 58 set_uniform_split(bool flag)59 void set_uniform_split(bool flag) { uniform_split_ = flag; } 60 device_arrangement()61 Arrangement device_arrangement() const { return device_arrangement_; } 62 tensor_map()63 Map tensor_map() const { return tensor_map_; } 64 tensor_shape()65 Arrangement tensor_shape() const { return tensor_shape_; } 66 origin_tensor_map()67 Map origin_tensor_map() const { return tensor_map_origin_; } 68 69 std::shared_ptr<TensorLayout> ExpandTensorShape(const Arrangement &expanded_shape) const; 70 71 std::shared_ptr<TensorLayout> ExpandDeviceArrangement(const Arrangement &expanded_arrangement) const; 72 IsSameTensorShape(const TensorLayout & tensor_layout)73 bool IsSameTensorShape(const TensorLayout &tensor_layout) const { 74 return (tensor_shape_ == tensor_layout.tensor_shape()); 75 } 76 IsSameDeviceArrangement(const TensorLayout & tensor_layout)77 bool IsSameDeviceArrangement(const TensorLayout &tensor_layout) const { 78 return (device_arrangement_ == tensor_layout.device_arrangement()); 79 } 80 IsSameTensorMap(const TensorLayout & tensor_layout)81 bool IsSameTensorMap(const TensorLayout &tensor_layout) const { return (tensor_map_ == tensor_layout.tensor_map()); } 82 83 bool operator==(const TensorLayout &t1) const; 84 85 bool operator!=(const TensorLayout &t1) const; 86 87 bool TensorShapeCanBeExpanded(const Arrangement &expanded_shape) const; 88 89 std::shared_ptr<Arrangement> ComputeExpandedTensorShape(const Arrangement &expand_shape) const; 90 91 Arrangement slice_shape() const; 92 93 Status UpdateTensorMap(size_t index, int64_t value); 94 95 TensorLayout SqueezeShape() const; 96 97 TensorLayout TransferRepeatLayout() const; 98 99 Status GenerateOptShardSliceShape(); 100 opt_shard_slice_shape()101 Shape opt_shard_slice_shape() { return opt_shard_slice_shape_; } 102 set_opt_shard_group(std::string name)103 void set_opt_shard_group(std::string name) { opt_shard_group_ = std::move(name); } 104 opt_shard_group()105 std::string opt_shard_group() { return opt_shard_group_; } 106 set_opt_shard_mirror_group(std::string name)107 void set_opt_shard_mirror_group(std::string name) { opt_shard_mirror_group_ = std::move(name); } 108 opt_shard_mirror_group()109 std::string opt_shard_mirror_group() { return opt_shard_mirror_group_; } 110 set_opt_weight_shard_step(int32_t step)111 void set_opt_weight_shard_step(int32_t step) { opt_weight_shard_step_ = step; } 112 opt_weight_shard_step()113 int32_t opt_weight_shard_step() { return opt_weight_shard_step_; } 114 set_opt_weight_shard_size(int32_t size)115 void set_opt_weight_shard_size(int32_t size) { opt_weight_shard_size_ = size; } 116 opt_weight_shard_size()117 int32_t opt_weight_shard_size() { return opt_weight_shard_size_; } 118 set_is_shared_param(bool is_shared_param)119 void set_is_shared_param(bool is_shared_param) { is_shared_param_ = is_shared_param; } 120 is_shared_param()121 bool is_shared_param() { return is_shared_param_; } 122 123 // Key for user data. 124 constexpr static char key[] = "TLayout"; 125 126 private: 127 std::shared_ptr<TensorLayout> ExpandTensorShapeWithoutExtendDeviceArrangement( 128 const Arrangement &expanded_shape) const; 129 std::shared_ptr<Arrangement> ComputeArrangementByExpandedShape(const Arrangement &tensor_shape) const; 130 bool IsValidTensorLayout() const; 131 void RemoveElementEqualToOneInDeviceArrangement(); 132 int64_t GetSliceDeviceDimensionByTensorDimensionIndex(uint64_t idx) const; 133 int64_t GetSliceNumByTensorDimensionIndex(uint64_t idx) const; 134 bool TensorShapeDimensionIsDividedBySplitDeviceDimension() const; 135 int64_t GetTensorDimensionIndexByDeviceDimensionIndex(int64_t idx) const; 136 137 Arrangement device_arrangement_origin_; 138 Arrangement tensor_shape_origin_; 139 Arrangement device_arrangement_; 140 Arrangement tensor_shape_; 141 Map tensor_map_; 142 Map tensor_map_origin_; 143 bool skip_redistribution_ = false; 144 bool uniform_split_ = true; 145 bool layout_transfer_ = false; 146 int32_t field_size_ = 0; 147 Shape opt_shard_slice_shape_; 148 std::string opt_shard_group_ = ""; // for allgather 149 std::string opt_shard_mirror_group_ = ""; // for mirror ops 150 int32_t opt_weight_shard_step_ = 0; 151 int32_t opt_weight_shard_size_ = 0; 152 bool is_shared_param_ = false; 153 }; 154 } // namespace parallel 155 } // namespace mindspore 156 157 #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_TENSOR_LAYOUT_H_ 158