/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/parallel_optimizer/opt_param_mgr.h"
#include <string>
#include <vector>
#include <functional>
#include <map>
#include <memory>
#include <numeric>
#include "frontend/parallel/ops_info/operator_info.h"
#include "include/common/utils/parallel_context.h"
#include "ir/dtype/type_id.h"

namespace mindspore {
namespace parallel {
class OptParamMgrImpl : public OptParamMgr {
 public:
  explicit OptParamMgrImpl(const FuncGraphPtr &root) : root_(root) {}
  virtual ~OptParamMgrImpl() = default;

  // Returns the name of the communication group over which the parameter is sharded by the
  // parallel optimizer, or an empty string if the parameter should not (or cannot) be sharded.
  std::string ShardOptGroup(const AnfNodePtr &parameter, TensorLayout *const tensor_layout,
                            const OperatorInfoPtr &distribute_operator) const override {
    if (!SplitParam(parameter)) {
      return "";
    }

    Status ret = tensor_layout->GenerateOptShardSliceShape();
    if (ret != Status::SUCCESS) {
      MS_LOG(INFO) << parameter->ToString() << "'s distributed shape " << tensor_layout->slice_shape().ToString()
                   << " does not satisfy the conditions.";
      return "";
    }

    // The weight is repeated across devices and the first dimension of its slice shape is divisible,
    // so the parallel optimizer can shard it. Create the communication group used by the AllGather
    // operator that restores the full weight.
    std::string opt_shard_group;
    std::vector<Group> dev_group;
    MS_LOG(INFO) << "Creating shard group for param: " << parameter->ToString()
                 << ", shape: " << parameter->Shape()->ToString();
    if (distribute_operator->CreateGroupForOptShard(tensor_layout, &dev_group) == Status::SUCCESS &&
        !dev_group.empty()) {
      opt_shard_group = dev_group[0].name();
      MS_LOG(INFO) << "Create group success.";
    } else {
      MS_LOG(WARNING) << "Create opt shard group for the parameter " << parameter->ToString() << " failed.";
    }
    return opt_shard_group;
  }
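  // Usage note (typical flow; not enforced in this file): a non-empty group name is recorded on the
  // parameter so that later passes can insert an AllGather over that group to rebuild the full
  // weight, while an empty string leaves the parameter fully replicated.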

 private:
  size_t ComputeShapeSize(const AnfNodePtr &parameter) const {
    ShapeVector shape(parameter->Shape()->cast<abstract::ShapePtr>()->shape());
    size_t total_size = std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
    return total_size;
  }

  // Returns the parameter's memory footprint in bytes.
  size_t ComputeMemorySize(const AnfNodePtr &parameter) const {
    // key: type id, value: bytes per element (Int4 is conservatively counted as one byte)
    const std::map<TypeId, size_t> dtype_size_map = {
      {kNumberTypeBool, sizeof(bool)},       {kNumberTypeInt4, sizeof(int8_t)},
      {kNumberTypeInt8, sizeof(int8_t)},     {kNumberTypeInt16, sizeof(int16_t)},
      {kNumberTypeInt32, sizeof(int32_t)},   {kNumberTypeInt64, sizeof(int64_t)},
      {kNumberTypeFloat16, sizeof(float16)}, {kNumberTypeFloat32, sizeof(float)},
      {kNumberTypeFloat64, sizeof(double)},  {kNumberTypeUInt8, sizeof(uint8_t)},
      {kNumberTypeUInt16, sizeof(uint16_t)}, {kNumberTypeUInt32, sizeof(uint32_t)},
      {kNumberTypeUInt64, sizeof(uint64_t)}, {kNumberTypeBFloat16, sizeof(bfloat16)}};

    size_t shape_size = ComputeShapeSize(parameter);
    TypeId type_id = parameter->Type()->cast<mindspore::TensorTypePtr>()->element()->type_id();
    auto iter = dtype_size_map.find(type_id);
    if (iter == dtype_size_map.end()) {
      MS_LOG(EXCEPTION) << "Unsupported type of parameter: " << parameter->DebugString();
    }
    return shape_size * iter->second;
  }
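  // Worked example (illustrative values, not taken from a real graph): a float16 parameter of
  // shape [4096, 8] has 4096 * 8 = 32768 elements, so ComputeMemorySize returns
  // 32768 * sizeof(float16) = 65536 B, i.e. exactly the 64 KB default threshold used below.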

  // Threshold configured by the user, in KB; -1 means no user-defined threshold was set.
  int64_t GetThresholdFromUsrInput() const {
    return ParallelContext::GetInstance()->get_parallel_optimizer_threshold();
  }

  // Decides whether the parallel optimizer should shard this parameter.
  bool SplitParam(const AnfNodePtr &parameter) const {
    if (!ParallelContext::GetInstance()->enable_parallel_optimizer()) {
      MS_LOG(INFO) << "Parallel optimizer: feature is not enabled. Skipped.";
      return false;
    }

    auto param_ptr = parameter->cast<ParameterPtr>();
    if ((!param_ptr) || (!param_ptr->has_default())) {
      MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " is not a parameter with a default value.";
      return false;
    }

    if (param_ptr->param_info() && !param_ptr->param_info()->parallel_optimizer()) {
      MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " is manually set to be skipped.";
      return false;
    }

    size_t param_split_threshold = DEFAULT_VAL * KB_SIZE;
    int64_t user_define_threshold = GetThresholdFromUsrInput();
    if (user_define_threshold != -1) {
      MS_LOG(INFO) << "Parallel optimizer: use user-defined threshold = " << user_define_threshold << "KB.";
      param_split_threshold = static_cast<size_t>(user_define_threshold) * KB_SIZE;
    } else {
      MS_LOG(INFO) << "Parallel optimizer: use DEFAULT threshold = " << DEFAULT_VAL << "KB.";
    }
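    // Illustrative decision with the default 64 KB (65536 B) threshold: a float32 parameter with
    // 10000 elements (40000 B) is skipped below, while one with 20000 elements (80000 B) passes
    // the size check and remains a candidate for sharding.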

    size_t param_size = ComputeMemorySize(parameter);
    MS_LOG(INFO) << "Parallel optimizer: " << parameter->ToString() << " size = " << param_size << "B.";
    if (param_size < param_split_threshold) {
      MS_LOG(INFO) << "Parallel optimizer: the size of " << parameter->ToString() << " (" << param_size
                   << "B) is smaller than the threshold (" << param_split_threshold << "B). Skipped.";
      if (param_ptr->param_info()) {
        param_ptr->param_info()->set_parallel_optimizer(false);
      }
      return false;
    }
    return true;
  }

  FuncGraphPtr root_;
  const size_t DEFAULT_VAL = 64;  // default split threshold, unit: KB
  const size_t KB_SIZE = 1024;    // bytes per KB
};

std::unique_ptr<OptParamMgr> createOptParamMgr(const FuncGraphPtr &root) {
  return std::make_unique<OptParamMgrImpl>(root);
}
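// Example usage (sketch; `root`, `param`, `layout`, and `op_info` are placeholders supplied by the
// parallel pass that owns this manager):
//   std::unique_ptr<OptParamMgr> mgr = createOptParamMgr(root);
//   std::string group = mgr->ShardOptGroup(param, &layout, op_info);
//   bool sharded = !group.empty();  // empty means the parameter is not optimizer-sharded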
}  // namespace parallel
}  // namespace mindspore