/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_CONTEXT_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_CONTEXT_H_

#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "abstract/abstract_value.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/status.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "utils/convert_utils.h"
#include "utils/info.h"
#include "pipeline/jit/pipeline.h"

namespace mindspore {
namespace parallel {
constexpr char STAND_ALONE[] = "stand_alone";
constexpr char DATA_PARALLEL[] = "data_parallel";
constexpr char HYBRID_PARALLEL[] = "hybrid_parallel";
constexpr char AUTO_PARALLEL[] = "auto_parallel";
constexpr char SEMI_AUTO_PARALLEL[] = "semi_auto_parallel";

constexpr char DYNAMIC_PROGRAMMING[] = "dynamic_programming";
constexpr char RECURSIVE_PROGRAMMING[] = "recursive_programming";

constexpr char TRAINING[] = "training";
constexpr char ACCUMULATION[] = "accumulation";

constexpr char ALL_GROUP_PARALLEL[] = "all_group_parallel";
constexpr char SAME_SERVER_GROUP_PARALLEL[] = "same_server_group_parallel";
constexpr char NO_GROUP_PARALLEL[] = "no_group_parallel";

constexpr char IS_FIRST_ITERATION[] = "is_first_iteration";
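// Process-wide singleton holding the configuration for parallel training: parallel mode, device
// number, global rank, pipeline stages, gradient accumulation, all-reduce fusion settings,
// strategy checkpoint files, and parallel-optimizer options. Obtain it via GetInstance() and
// restore the defaults with Reset(); copying and assignment are deleted.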
class ParallelContext {
 public:
  ~ParallelContext() = default;
  ParallelContext(const ParallelContext &) = delete;
  ParallelContext &operator=(const ParallelContext &) = delete;

  static std::shared_ptr<ParallelContext> GetInstance();

  void set_gradients_mean(bool gradients_mean);
  bool gradients_mean() const { return gradients_mean_; }

  void set_full_batch(bool full_batch);
  bool full_batch() const { return full_batch_; }

  void set_dataset_strategy(const std::vector<std::vector<int64_t>> &dataset_strategy);
  std::vector<std::vector<int64_t>> dataset_strategy() const { return dataset_strategy_; }

  void set_gradient_fp32_sync(bool gradient_fp32_sync);
  bool gradient_fp32_sync() const { return gradient_fp32_sync_; }

  void set_loss_repeated_mean(bool loss_repeated_mean);
  bool loss_repeated_mean() const { return loss_repeated_mean_; }

  void set_device_num(int64_t device_num);
  int64_t device_num() const { return device_num_; }

  void set_pipeline_stage_split_num(const int64_t stages);
  int64_t pipeline_stage_split_num() const { return pipeline_stage_split_num_; }

  void set_global_rank(int64_t global_rank);
  int64_t global_rank() const { return global_rank_; }

  void set_grad_accumulation_step(int64_t grad_accumulation_step);
  int64_t grad_accumulation_step() const { return grad_accumulation_step_; }

  bool set_parallel_mode(const std::string &parallel_mode);
  std::string parallel_mode() const { return parallel_mode_; }

  bool set_strategy_search_mode(const std::string &strategy_search_mode);
  std::string strategy_search_mode() const { return strategy_search_mode_; }

  void set_parameter_broadcast(bool parameter_broadcast);
  bool parameter_broadcast() const { return parameter_broadcast_; }

  bool device_num_is_set() const { return device_num_is_set_; }
  bool global_rank_is_set() const { return global_rank_is_set_; }
  bool parameter_broadcast_is_set() const { return parameter_broadcast_is_set_; }

  void set_optimizer_weight_shard_size(int64_t optimizer_weight_shard_size);
  int64_t optimizer_weight_shard_size() const { return optimizer_weight_shard_size_; }
  void set_optimizer_weight_shard_aggregated_save(bool optimizer_weight_shard_aggregated_save);
  bool optimizer_weight_shard_aggregated_save() const { return optimizer_weight_shard_aggregated_save_; }

  void SetAllReduceFusionSplitIndices(const std::vector<uint32_t> &indices, const std::string &group);
  std::vector<uint32_t> GetAllReduceFusionSplitIndices(const std::string &group) const;
  void SetAllReduceFusionSplitSizes(const std::vector<uint32_t> &sizes, const std::string &group);
  std::vector<uint32_t> GetAllReduceFusionSplitSizes(const std::string &group) const;
  void set_enable_all_reduce_fusion(bool enable_all_reduce_fusion) {
    enable_all_reduce_fusion_ = enable_all_reduce_fusion;
  }
  bool enable_all_reduce_fusion() const { return enable_all_reduce_fusion_; }

  void set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file);
  std::string strategy_ckpt_load_file() const { return strategy_ckpt_load_file_; }
  void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file);
  std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; }
  void set_group_ckpt_save_file(const std::string &group_ckpt_save_file);
  std::string group_ckpt_save_file() const { return group_ckpt_save_file_; }

  void set_enable_parallel_optimizer(bool enable_parallel_optimizer) {
    enable_parallel_optimizer_ = enable_parallel_optimizer;
  }
  bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; }

  bool set_communi_parallel_mode(const std::string &communi_parallel_mode);
  std::string communi_parallel_mode() const { return communi_parallel_mode_; }
  void set_sharding_propagation(const bool);
  bool sharding_propagation() const { return sharding_propagation_; }
  void set_enable_all2all(const bool);
  bool enable_all2all() const { return enable_all2all_; }

  void Reset();
  void ParallelParameterContextInitShape(const FuncGraphPtr &func_graph);
  void ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
                                            const AbstractBasePtr &ptr);
  void ParallelParameterContextCkptShape(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
                                         const AbstractBasePtr &ptr);

 private:
  ParallelContext();
  static std::shared_ptr<ParallelContext> inst_context_;
  bool gradients_mean_;
  bool full_batch_;
  bool gradient_fp32_sync_;
  bool loss_repeated_mean_;
  int64_t device_num_;
  int64_t global_rank_;
  int64_t grad_accumulation_step_;
  std::string parallel_mode_;
  std::string strategy_search_mode_;
  int64_t pipeline_stage_split_num_;
  bool parameter_broadcast_;
  bool device_num_is_set_;
  bool global_rank_is_set_;
  bool parameter_broadcast_is_set_;
  bool enable_all_reduce_fusion_;
  std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_indices_;
  std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_sizes_;
  std::string strategy_ckpt_load_file_;
  std::string strategy_ckpt_save_file_;
  std::string group_ckpt_save_file_;
  bool enable_parallel_optimizer_;
  bool init_param_shape_;
  std::string communi_parallel_mode_;
  int64_t optimizer_weight_shard_size_;
  bool optimizer_weight_shard_aggregated_save_;
  // In AUTO_PARALLEL mode, 'sharding_propagation_' = True indicates that sharding-configured operators
  // will propagate the sharding strategies to other operators with minimum redistribution cost.
  bool sharding_propagation_;
  // Enable AllToAll or not. If false, use AllGather and Split.
  bool enable_all2all_;
  std::vector<std::vector<int64_t>> dataset_strategy_;
};

}  // namespace parallel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_CONTEXT_H_
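
// A minimal usage sketch of the interface declared above, assuming a MindSpore build tree that
// provides this header and its dependencies. The concrete values (8 devices, rank 0) are
// illustrative only; where a setter is declared to return bool, the return value is taken to
// report whether the supplied value was accepted.
//
//   auto parallel_context = mindspore::parallel::ParallelContext::GetInstance();
//   parallel_context->set_device_num(8);
//   parallel_context->set_global_rank(0);
//   parallel_context->set_gradients_mean(true);
//   if (!parallel_context->set_parallel_mode(mindspore::parallel::SEMI_AUTO_PARALLEL)) {
//     // A false return presumably signals that the mode string was not accepted.
//   }
//   int64_t device_num = parallel_context->device_num();   // 8
//   std::string mode = parallel_context->parallel_mode();  // "semi_auto_parallel"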