/**
 * Copyright 2019-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_TENSOR_LAYOUT_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_TENSOR_LAYOUT_H_

#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <functional>
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/status.h"
#include "frontend/parallel/tensor_layout/arrangement.h"
#include "frontend/parallel/tensor_layout/map.h"
#include "include/common/utils/convert_utils.h"

namespace mindspore {
namespace parallel {
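// TensorLayout describes how a tensor is distributed over a logical device
// mesh: device_arrangement gives the shape of the mesh, tensor_map maps each
// tensor dimension to a mesh dimension (-1 meaning the dimension is not
// split), and tensor_shape is the full, unsliced shape of the tensor.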
class TensorLayout {
 public:
  TensorLayout() = default;
  ~TensorLayout() = default;
  std::string ToString() const;
  std::string StandardToString() const;
  std::string OriginToString() const;
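  // Initialization: Init takes pre-built Arrangement/Map objects; the
  // InitFromVector variants build them from raw shapes. All return FAILED
  // when the arguments do not form a valid layout.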
  Status Init(const Arrangement &device_arrangement, const Map &tensor_map, const Arrangement &tensor_shape);
  Status InitFromVector(const Shape &device_arrangement, const Shape &tensor_map, const Shape &tensor_shape);
  Status InitFromExtendVector(const Shape &device_matrix, const std::vector<Shape> &tensor_map,
                              const Shape &tensor_shape, bool interleaved_parallel = false,
                              bool check_device_num = true);

  Status UpdateTensorShape(size_t index, int64_t update_value) {
    return this->tensor_shape_.UpdateTensorShape(index, update_value);
  }

  bool skip_redistribution() const { return skip_redistribution_; }

  void set_skip_redistribution(bool flag) { skip_redistribution_ = flag; }

  bool layout_transfer() const { return layout_transfer_; }

  void set_layout_transfer(bool flag) { layout_transfer_ = flag; }

  int64_t get_field_size() const { return field_size_; }

  void set_field_size(int64_t field_size) { field_size_ = field_size; }

  bool uniform_split() const { return uniform_split_; }

  void set_uniform_split(bool flag) { uniform_split_ = flag; }

  Arrangement device_arrangement() const { return device_arrangement_; }

  Map tensor_map() const { return tensor_map_; }

  Arrangement tensor_shape() const { return tensor_shape_; }

  Arrangement tensor_shape_origin() const { return tensor_shape_origin_; }

  Arrangement device_arrangement_origin() const { return device_arrangement_origin_; }

  Map origin_tensor_map() const { return tensor_map_origin_; }

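  // Expansion helpers: produce an equivalent layout over a finer-grained
  // tensor shape or device arrangement (nullptr when no valid expansion
  // exists).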
  std::shared_ptr<TensorLayout> ExpandTensorShape(const Arrangement &expanded_shape) const;

  std::shared_ptr<TensorLayout> ExpandDeviceArrangement(const Arrangement &expanded_arrangement) const;

  bool IsSameTensorShape(const TensorLayout &tensor_layout) const {
    return (tensor_shape_ == tensor_layout.tensor_shape());
  }

  bool IsSameDeviceArrangement(const TensorLayout &tensor_layout) const {
    return (device_arrangement_ == tensor_layout.device_arrangement());
  }

  bool IsSameTensorMap(const TensorLayout &tensor_layout) const { return (tensor_map_ == tensor_layout.tensor_map()); }

  bool operator==(const TensorLayout &t1) const;

  bool operator!=(const TensorLayout &t1) const;

  bool IsSameWithoutSplit(const TensorLayout &t1) const;

  bool IsInterleavedParallel() const;

  bool TensorShapeCanBeExpanded(const Arrangement &expand_shape) const;

  std::shared_ptr<Arrangement> ComputeExpandedTensorShape(const Arrangement &expand_shape) const;

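  // slice_shape() is the shape of the local slice held by a single device:
  // each split dimension of tensor_shape divided by the size of the mesh
  // dimension it is mapped to.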
  Arrangement slice_shape() const;

  Arrangement base_slice_shape() const;

  Shape shard_strategy() const;

  Status UpdateTensorMap(size_t index, int64_t value);

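  // SqueezeShape() returns a copy of this layout with size-one tensor
  // dimensions removed; TransferRepeatLayout() returns a layout with the same
  // device arrangement but every tensor dimension unsplit (fully replicated).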
  TensorLayout SqueezeShape() const;

  TensorLayout TransferRepeatLayout() const;

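  // Optimizer weight sharding (parallel optimizer): GenerateOptShardSliceShape
  // computes the slice shape of a parameter that is further sharded across its
  // repeated (data-parallel) devices; the opt_shard_* accessors below expose
  // the resulting slice shape and communication groups.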
  Status GenerateOptShardSliceShape();

  Shape opt_shard_slice_shape() { return opt_shard_slice_shape_; }

  void set_opt_shard_slice_shape(Shape opt_slice_shape) { opt_shard_slice_shape_ = std::move(opt_slice_shape); }

  void set_opt_shard_group(std::string name) { opt_shard_group_ = std::move(name); }

  std::string opt_shard_group() const { return opt_shard_group_; }

  void set_opt_shard_mirror_group(std::string name) { opt_shard_mirror_group_ = std::move(name); }

  std::string opt_shard_mirror_group() { return opt_shard_mirror_group_; }

  void set_opt_weight_shard_step(int32_t step) { opt_weight_shard_step_ = step; }

  int32_t opt_weight_shard_step() const { return opt_weight_shard_step_; }

  void set_opt_weight_shard_size(int32_t size) { opt_weight_shard_size_ = size; }

  int32_t opt_weight_shard_size() const { return opt_weight_shard_size_; }

  void set_is_shared_param(bool is_shared_param) { is_shared_param_ = is_shared_param; }

  bool is_shared_param() const { return is_shared_param_; }

  void set_tensor_shape_before(const Shape &tensor_shape_before) { tensor_shape_before_.Init(tensor_shape_before); }

  RankList InferRepeatedGroup();

  Arrangement tensor_shape_before() const { return tensor_shape_before_; }

  std::vector<Shape> tensor_map_before() const { return tensor_map_before_; }

  int64_t GetSliceNumByTensorDimensionIndex(uint64_t idx) const;

  TensorLayout LayoutForRedistribution() const;

  std::vector<int64_t> GetVirtualRank() const;

  Arrangement device_arrangement_interleaved() { return device_arrangement_interleaved_; }

  void set_device_arrangement_interleaved(Arrangement device_arrangement_interleaved) {
    device_arrangement_interleaved_ = device_arrangement_interleaved;
  }
  // Key for user data.
  constexpr static char key[] = "TLayout";

 private:
  std::shared_ptr<TensorLayout> ExpandTensorShapeWithoutExtendDeviceArrangement(
    const Arrangement &expanded_shape) const;
  std::shared_ptr<Arrangement> ComputeArrangementByExpandedShape(const Arrangement &tensor_shape) const;
  bool IsValidTensorLayout() const;
  void RemoveElementEqualToOneInDeviceArrangement();
  int64_t GetSliceDeviceDimensionByTensorDimensionIndex(uint64_t idx) const;
  bool TensorShapeDimensionIsDividedBySplitDeviceDimension() const;
  int64_t GetTensorDimensionIndexByDeviceDimensionIndex(int64_t idx) const;

  Arrangement device_arrangement_origin_;
  Arrangement tensor_shape_origin_;
  Arrangement device_arrangement_interleaved_;
  Arrangement device_arrangement_;
  Arrangement tensor_shape_;
  Arrangement tensor_shape_before_;
  Map tensor_map_;
  Map tensor_map_origin_;
  std::vector<Shape> tensor_map_before_;
  bool skip_redistribution_ = false;
  bool uniform_split_ = true;
  bool layout_transfer_ = false;
  int64_t field_size_ = 0;
  Shape opt_shard_slice_shape_;
  std::string opt_shard_group_ = "";         // for allgather
  std::string opt_shard_mirror_group_ = "";  // for mirror ops
  int32_t opt_weight_shard_step_ = 0;
  int32_t opt_weight_shard_size_ = 0;
  bool is_shared_param_ = false;
};
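// A minimal usage sketch, not part of this header. It assumes an 8-device
// mesh shaped [2, 4] and the usual MindSpore convention that a tensor-map
// value indexes the device arrangement from its last element (-1 meaning
// "not split"):
//
//   TensorLayout layout;
//   // Split dim 0 of a [128, 256] tensor by the mesh axis of size 2 and
//   // dim 1 by the axis of size 4, hence tensor map {1, 0}.
//   if (layout.InitFromVector({2, 4}, {1, 0}, {128, 256}) == SUCCESS) {
//     Arrangement slice = layout.slice_shape();  // expected per-rank slice: [64, 64]
//   }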
}  // namespace parallel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_TENSOR_LAYOUT_H_