/**
 * Copyright 2019-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_TENSOR_LAYOUT_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_TENSOR_LAYOUT_H_

#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <functional>
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/status.h"
#include "frontend/parallel/tensor_layout/arrangement.h"
#include "frontend/parallel/tensor_layout/map.h"
#include "include/common/utils/convert_utils.h"

namespace mindspore {
namespace parallel {
class TensorLayout {
 public:
  TensorLayout() = default;
  ~TensorLayout() = default;
  std::string ToString() const;
  std::string StandardToString() const;
  std::string OriginToString() const;
  Status Init(const Arrangement &device_arrangement, const Map &tensor_map, const Arrangement &tensor_shape);
  Status InitFromVector(const Shape &device_arrangement, const Shape &tensor_map, const Shape &tensor_shape);
  Status InitFromExtendVector(const Shape &device_matrix, const std::vector<Shape> &tensor_map,
                              const Shape &tensor_shape, bool interleaved_parallel = false,
                              bool check_device_num = true);
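
  // A minimal usage sketch (illustrative only; the concrete values below are
  // assumptions, not taken from this header): with 8 devices viewed as a
  // 2 x 4 grid, a 2-D tensor whose first dimension is split across the grid
  // axis of size 2 and whose second dimension is replicated could be
  // initialized as follows. A tensor-map entry m names a device axis counted
  // from the tail of the arrangement, and -1 marks an unsplit dimension.
  //
  //   TensorLayout layout;
  //   if (layout.InitFromVector({2, 4}, {1, -1}, {128, 256}) != SUCCESS) {
  //     // reject the invalid layout (SUCCESS comes from frontend/parallel/status.h)
  //   }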

  Status UpdateTensorShape(size_t index, int64_t update_value) {
    return this->tensor_shape_.UpdateTensorShape(index, update_value);
  }

  bool skip_redistribution() const { return skip_redistribution_; }

  void set_skip_redistribution(bool flag) { skip_redistribution_ = flag; }

  bool layout_transfer() const { return layout_transfer_; }

  void set_layout_transfer(bool flag) { layout_transfer_ = flag; }

  int64_t get_field_size() const { return field_size_; }

  void set_field_size(int64_t field_size) { field_size_ = field_size; }

  bool uniform_split() const { return uniform_split_; }

  void set_uniform_split(bool flag) { uniform_split_ = flag; }

  Arrangement device_arrangement() const { return device_arrangement_; }

  Map tensor_map() const { return tensor_map_; }

  Arrangement tensor_shape() const { return tensor_shape_; }

  Arrangement tensor_shape_origin() const { return tensor_shape_origin_; }

  Arrangement device_arrangement_origin() const { return device_arrangement_origin_; }

  Map origin_tensor_map() const { return tensor_map_origin_; }

  std::shared_ptr<TensorLayout> ExpandTensorShape(const Arrangement &expanded_shape) const;

  std::shared_ptr<TensorLayout> ExpandDeviceArrangement(const Arrangement &expanded_arrangement) const;

  bool IsSameTensorShape(const TensorLayout &tensor_layout) const {
    return (tensor_shape_ == tensor_layout.tensor_shape());
  }

  bool IsSameDeviceArrangement(const TensorLayout &tensor_layout) const {
    return (device_arrangement_ == tensor_layout.device_arrangement());
  }

  bool IsSameTensorMap(const TensorLayout &tensor_layout) const { return (tensor_map_ == tensor_layout.tensor_map()); }

  bool operator==(const TensorLayout &t1) const;

  bool operator!=(const TensorLayout &t1) const;

  bool IsSameWithoutSplit(const TensorLayout &t1) const;

  bool IsInterleavedParallel() const;

  bool TensorShapeCanBeExpanded(const Arrangement &expand_shape) const;

  std::shared_ptr<Arrangement> ComputeExpandedTensorShape(const Arrangement &expand_shape) const;

  Arrangement slice_shape() const;

  Arrangement base_slice_shape() const;

  Shape shard_strategy() const;
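
  // Sketch of the intended relationship between the accessors above (an
  // illustrative assumption, not a contract stated in this header): for a
  // layout built from device arrangement {2, 4}, tensor map {1, -1} and
  // tensor shape {128, 256}, shard_strategy() would report {2, 1} slices per
  // dimension and slice_shape() the per-device shape {64, 256}, i.e.
  // slice_shape[i] == tensor_shape[i] / shard_strategy[i].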

  Status UpdateTensorMap(size_t index, int64_t value);

  TensorLayout SqueezeShape() const;

  TensorLayout TransferRepeatLayout() const;

  Status GenerateOptShardSliceShape();

  Shape opt_shard_slice_shape() { return opt_shard_slice_shape_; }

  void set_opt_shard_slice_shape(Shape opt_slice_shape) { opt_shard_slice_shape_ = std::move(opt_slice_shape); }

  void set_opt_shard_group(std::string name) { opt_shard_group_ = std::move(name); }

  std::string opt_shard_group() const { return opt_shard_group_; }

  void set_opt_shard_mirror_group(std::string name) { opt_shard_mirror_group_ = std::move(name); }

  std::string opt_shard_mirror_group() { return opt_shard_mirror_group_; }

  void set_opt_weight_shard_step(int32_t step) { opt_weight_shard_step_ = step; }

  int32_t opt_weight_shard_step() const { return opt_weight_shard_step_; }

  void set_opt_weight_shard_size(int32_t size) { opt_weight_shard_size_ = size; }

  int32_t opt_weight_shard_size() const { return opt_weight_shard_size_; }

  void set_is_shared_param(bool is_shared_param) { is_shared_param_ = is_shared_param; }

  bool is_shared_param() const { return is_shared_param_; }

  void set_tensor_shape_before(const Shape &tensor_shape_before) { tensor_shape_before_.Init(tensor_shape_before); }

  RankList InferRepeatedGroup();

  Arrangement tensor_shape_before() const { return tensor_shape_before_; }

  std::vector<Shape> tensor_map_before() const { return tensor_map_before_; }

  int64_t GetSliceNumByTensorDimensionIndex(uint64_t idx) const;

  TensorLayout LayoutForRedistribution() const;

  std::vector<int64_t> GetVirtualRank() const;

  Arrangement device_arrangement_interleaved() { return device_arrangement_interleaved_; }

  void set_device_arrangement_interleaved(Arrangement device_arrangement_interleaved) {
    device_arrangement_interleaved_ = device_arrangement_interleaved;
  }

  // Key for user data.
  constexpr static char key[] = "TLayout";

 private:
  std::shared_ptr<TensorLayout> ExpandTensorShapeWithoutExtendDeviceArrangement(
    const Arrangement &expanded_shape) const;
  std::shared_ptr<Arrangement> ComputeArrangementByExpandedShape(const Arrangement &tensor_shape) const;
  bool IsValidTensorLayout() const;
  void RemoveElementEqualToOneInDeviceArrangement();
  int64_t GetSliceDeviceDimensionByTensorDimensionIndex(uint64_t idx) const;
  bool TensorShapeDimensionIsDividedBySplitDeviceDimension() const;
  int64_t GetTensorDimensionIndexByDeviceDimensionIndex(int64_t idx) const;

  Arrangement device_arrangement_origin_;
  Arrangement tensor_shape_origin_;
  Arrangement device_arrangement_interleaved_;
  Arrangement device_arrangement_;
  Arrangement tensor_shape_;
  Arrangement tensor_shape_before_;
  Map tensor_map_;
  Map tensor_map_origin_;
  std::vector<Shape> tensor_map_before_;
  bool skip_redistribution_ = false;
  bool uniform_split_ = true;
  bool layout_transfer_ = false;
  int64_t field_size_ = 0;
  Shape opt_shard_slice_shape_;
  std::string opt_shard_group_ = "";         // for allgather
  std::string opt_shard_mirror_group_ = "";  // for mirror ops
  int32_t opt_weight_shard_step_ = 0;
  int32_t opt_weight_shard_size_ = 0;
  bool is_shared_param_ = false;
};
}  // namespace parallel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_TENSOR_LAYOUT_H_