/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_CAME_PARALLEL_HANDLER_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_CAME_PARALLEL_HANDLER_H_

#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "base/base.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/tensor_layout/tensor_layout.h"

namespace mindspore {
namespace parallel {
using TensorLayoutPtr = std::shared_ptr<TensorLayout>;

// Indices identifying the seven ReduceMean operations in the CAME optimizer graph.
constexpr size_t kFirstCameReduceMean = 1;
constexpr size_t kSecondCameReduceMean = 2;
constexpr size_t kThirdCameReduceMean = 3;
constexpr size_t kForthCameReduceMean = 4;
constexpr size_t kFifthCameReduceMean = 5;
constexpr size_t kSixthCameReduceMean = 6;
constexpr size_t kSeventhCameReduceMean = 7;
constexpr size_t kParameterDimTwo = 2;

// Name patterns used to identify the CAME optimizer state parameters.
constexpr char EXP_AVG[] = "exp_avg";
constexpr char EXP_AVG_SQ_ROW[] = "exp_avg_sq_row_";
constexpr char EXP_AVG_SQ_COL[] = "exp_avg_sq_col_";
constexpr char EXP_AVG_INSTA_ROW[] = "exp_avg_insta_row_";
constexpr char EXP_AVG_INSTA_COL[] = "exp_avg_insta_col_";
constexpr char EXP_AVG_SQ[] = "exp_avg_sq_";

// Handles the CAME optimizer sub-graph of one parallelized parameter: it locates the
// CAME state parameters tied to `origin`, finds the relevant ReduceMean nodes, and
// inserts AllReduce/RealDiv operations over communication groups derived from the
// parameter's tensor layout and optimizer sharding.
class CameCommHandler {
 public:
  CameCommHandler(ParameterPtr origin, const std::vector<AnfNodePtr> &all_parameters,
                  const NodeUsersMap &node_user_map);
  // Entry point: performs the full lookup-and-insert pass for `origin`.
  void Process();

 private:
  // The parameter whose CAME optimizer states are being handled.
  ParameterPtr origin;
  const std::vector<AnfNodePtr> &all_parameters;
  TensorLayoutPtr tensor_layout;
  const NodeUsersMap &node_user_map;

  int64_t cur_rank = -1;
  DeviceMatrix dev_matrix;
  RankList full_rank_list;

  // Whether `origin` is sharded by optimizer parallelism.
  bool is_opt_shard = false;

  // CAME optimizer state parameters associated with `origin`.
  ParameterPtr exp_avg_sq_row = nullptr;
  ParameterPtr exp_avg_sq_col = nullptr;
  ParameterPtr exp_avg = nullptr;
  ParameterPtr exp_avg_insta_row = nullptr;
  ParameterPtr exp_avg_insta_col = nullptr;

  std::set<size_t> reduce_mean_numbers = {kFirstCameReduceMean,  kSecondCameReduceMean, kThirdCameReduceMean,
                                          kForthCameReduceMean,  kFifthCameReduceMean,  kSixthCameReduceMean,
                                          kSeventhCameReduceMean};

  // Finds the CAME state parameters (exp_avg, exp_avg_sq_row/col, exp_avg_insta_row/col)
  // among all_parameters.
  void FindCameParams();

  // Finds the numbered ReduceMean node; ReduceMeans 1/2/5/6 and 3/7 are reached through
  // a given state parameter, while ReduceMean 4 is located separately.
  CNodePtr FindReduceMean(size_t number);
  CNodePtr FindReduceMean1256(const ParameterPtr &param);
  CNodePtr FindReduceMean37(const ParameterPtr &param);
  CNodePtr FindReduceMean4();

  // Rank-list queries for optimizer sharding and for a single device-matrix dimension.
  std::pair<Status, RankList> GetOptShardRankList(const int64_t rank);
  std::pair<Status, RankList> GetDimRankList(const int64_t rank, const int64_t dim);

  RankList ExpandRankListWithOptShard(const RankList &rank_list);
  RankList ExpandRankListWithDim(const RankList &base, const int64_t dim);

  // Creates a communication group covering rank_list and returns its name.
  std::string CreateCommGroupFromRankList(const RankList &rank_list);
  // Inserts an AllReduce followed by a RealDiv on the input of the given ReduceMean so
  // that the mean is taken across comm_rank_list.
  void InsertAllReduceAndRealDivToReduceMeanInput(CNodePtr reduce_mean, const RankList &comm_rank_list);
};
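
// Usage sketch (illustrative only; `func_graph`, `manager`, and the cast below are
// assumptions not defined in this header): for a parameter with CAME optimizer states,
// a handler is constructed from the graph's parameter list and its node-user map, and
// the rewrite is driven by a single Process() call.
//
//   auto param = func_graph->parameters().front()->cast<ParameterPtr>();
//   CameCommHandler handler(param, func_graph->parameters(), manager->node_users());
//   handler.Process();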
}  // namespace parallel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_CAME_PARALLEL_HANDLER_H_