/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_TENSOR_LAYOUT_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_TENSOR_LAYOUT_H_

#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/status.h"
#include "frontend/parallel/tensor_layout/arrangement.h"
#include "frontend/parallel/tensor_layout/map.h"
#include "utils/convert_utils.h"

namespace mindspore {
namespace parallel {
class TensorLayout {
 public:
  TensorLayout() = default;
  ~TensorLayout() = default;
  std::string ToString() const;
  std::string StandardToString() const;
  std::string OriginToString() const;
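  // Init binds a device arrangement, a tensor map, and a tensor shape into one
  // layout; InitFromVector builds the same layout from raw Shape vectors. In
  // the tensor map, entry i names the device dimension that shards tensor
  // dimension i (device dimensions are indexed from the rightmost axis, and -1
  // marks an unsplit dimension). A worked example follows the class body.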
  Status Init(const Arrangement &device_arrangement, const Map &tensor_map, const Arrangement &tensor_shape);
  Status InitFromVector(const Shape &device_arrangement, const Shape &tensor_map, const Shape &tensor_shape);

  bool skip_redistribution() const { return skip_redistribution_; }

  void set_skip_redistribution(bool flag) { skip_redistribution_ = flag; }

  bool layout_transfer() const { return layout_transfer_; }

  void set_layout_transfer(bool flag) { layout_transfer_ = flag; }

  int64_t get_field_size() const { return field_size_; }

  void set_field_size(int64_t field_size) { field_size_ = field_size; }

  bool uniform_split() const { return uniform_split_; }

  void set_uniform_split(bool flag) { uniform_split_ = flag; }

  Arrangement device_arrangement() const { return device_arrangement_; }

  Map tensor_map() const { return tensor_map_; }

  Arrangement tensor_shape() const { return tensor_shape_; }

  Map origin_tensor_map() const { return tensor_map_origin_; }

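  // Expansion helpers used during layout transfer: each derives a layout over
  // an expanded (finer-grained) tensor shape or device arrangement. Semantics
  // here are inferred from the signatures and the private helpers below.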
  std::shared_ptr<TensorLayout> ExpandTensorShape(const Arrangement &expanded_shape) const;

  std::shared_ptr<TensorLayout> ExpandDeviceArrangement(const Arrangement &expanded_arrangement) const;

  bool IsSameTensorShape(const TensorLayout &tensor_layout) const {
    return (tensor_shape_ == tensor_layout.tensor_shape());
  }

  bool IsSameDeviceArrangement(const TensorLayout &tensor_layout) const {
    return (device_arrangement_ == tensor_layout.device_arrangement());
  }

  bool IsSameTensorMap(const TensorLayout &tensor_layout) const { return (tensor_map_ == tensor_layout.tensor_map()); }

  bool operator==(const TensorLayout &t1) const;

  bool operator!=(const TensorLayout &t1) const;

  bool TensorShapeCanBeExpanded(const Arrangement &expanded_shape) const;

  std::shared_ptr<Arrangement> ComputeExpandedTensorShape(const Arrangement &expand_shape) const;

  Arrangement slice_shape() const;

  Status UpdateTensorMap(size_t index, int64_t value);

  TensorLayout SqueezeShape() const;

  TensorLayout TransferRepeatLayout() const;

  Status GenerateOptShardSliceShape();

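  // Optimizer weight-shard state (parallel optimizer): the slice shape
  // produced by GenerateOptShardSliceShape, the communication groups for the
  // allgather and mirror operators, and the shard step/size.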
  Shape opt_shard_slice_shape() { return opt_shard_slice_shape_; }

  void set_opt_shard_group(std::string name) { opt_shard_group_ = std::move(name); }

  std::string opt_shard_group() { return opt_shard_group_; }

  void set_opt_shard_mirror_group(std::string name) { opt_shard_mirror_group_ = std::move(name); }

  std::string opt_shard_mirror_group() { return opt_shard_mirror_group_; }

  void set_opt_weight_shard_step(int32_t step) { opt_weight_shard_step_ = step; }

  int32_t opt_weight_shard_step() { return opt_weight_shard_step_; }

  void set_opt_weight_shard_size(int32_t size) { opt_weight_shard_size_ = size; }

  int32_t opt_weight_shard_size() { return opt_weight_shard_size_; }

  void set_is_shared_param(bool is_shared_param) { is_shared_param_ = is_shared_param; }

  bool is_shared_param() { return is_shared_param_; }

  // Key for user data.
  constexpr static char key[] = "TLayout";

 private:
  std::shared_ptr<TensorLayout> ExpandTensorShapeWithoutExtendDeviceArrangement(
    const Arrangement &expanded_shape) const;
  std::shared_ptr<Arrangement> ComputeArrangementByExpandedShape(const Arrangement &tensor_shape) const;
  bool IsValidTensorLayout() const;
  void RemoveElementEqualToOneInDeviceArrangement();
  int64_t GetSliceDeviceDimensionByTensorDimensionIndex(uint64_t idx) const;
  int64_t GetSliceNumByTensorDimensionIndex(uint64_t idx) const;
  bool TensorShapeDimensionIsDividedBySplitDeviceDimension() const;
  int64_t GetTensorDimensionIndexByDeviceDimensionIndex(int64_t idx) const;

  Arrangement device_arrangement_origin_;
  Arrangement tensor_shape_origin_;
  Arrangement device_arrangement_;
  Arrangement tensor_shape_;
  Map tensor_map_;
  Map tensor_map_origin_;
  bool skip_redistribution_ = false;
  bool uniform_split_ = true;
  bool layout_transfer_ = false;
  int64_t field_size_ = 0;
  Shape opt_shard_slice_shape_;
  std::string opt_shard_group_ = "";         // for allgather
  std::string opt_shard_mirror_group_ = "";  // for mirror ops
  int32_t opt_weight_shard_step_ = 0;
  int32_t opt_weight_shard_size_ = 0;
  bool is_shared_param_ = false;
};
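
// Usage sketch (illustrative only; assumes the Shape and Status definitions
// from the includes above and the tensor-map convention noted in the class):
//
//   TensorLayout layout;
//   Shape device_arrangement = {2, 4};  // 2 x 4 device mesh
//   Shape tensor_map = {1, 0};          // dim 0 -> mesh dim of size 2, dim 1 -> size 4
//   Shape tensor_shape = {512, 1024};
//   if (layout.InitFromVector(device_arrangement, tensor_map, tensor_shape) == SUCCESS) {
//     Arrangement slice = layout.slice_shape();  // per-device slice: {512/2, 1024/4} = {256, 256}
//   }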
}  // namespace parallel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_TENSOR_LAYOUT_TENSOR_LAYOUT_H_