• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_DEVICE_MANAGER_H_
18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_DEVICE_MANAGER_H_
19 
20 #include <cstdint>
21 #include <cstring>
22 #include <map>
23 #include <memory>
24 #include <string>
25 #include <utility>
26 #include <vector>
27 #include "frontend/parallel/device.h"
28 #include "frontend/parallel/device_matrix.h"
29 #include "frontend/parallel/group_manager.h"
30 #include "frontend/parallel/status.h"
31 #include "frontend/parallel/strategy.h"
32 #include "include/common/utils/convert_utils.h"
33 #include "utils/ms_utils.h"
34 
35 namespace mindspore {
36 namespace parallel {
// Device-count limit: 0xFFFFFFFF == 4294967295, the largest value representable in uint32_t.
constexpr int64_t MAX_DEVICE_NUM = int64_t{0xFFFFFFFF};
// Number of devices hosted by one server.
constexpr size_t DEVICE_NUM_PER_SERVER{8};

// Backend name strings.
constexpr char HCCL_BACKEND[] = "hccl";
constexpr char NCCL_BACKEND[] = "nccl";
constexpr char UNDEFINED_BACKEND[] = "undefined_backend";

// Forward declaration so the shared-pointer alias and the global below can be declared
// before the class definition.
class DeviceManager;
using DeviceManagerPtr = std::shared_ptr<DeviceManager>;

// The globally unique device manager; declared extern here and defined out of line.
extern DeviceManagerPtr g_device_manager;

// Initializes the global 'g_device_manager' from 'device_num', 'global_rank', 'backend'
// (e.g. HCCL_BACKEND / NCCL_BACKEND) and 'stage'.
// NOTE(review): the meaning of the 'stage' entries and of the returned bool is inferred from
// names only -- confirm against the implementation.
bool InitDevice(int64_t device_num, int64_t global_rank, const std::string &backend, const std::vector<int64_t> &stage);

// NOTE(review): presumably verifies that 'g_device_manager' has been initialized -- confirm.
void CheckGlobalDeviceManager();

// Returns a hashed form of 'origin_name'.
std::string HashName(const std::string &origin_name);
55 
56 class DeviceManager {
57   // This class is used to manage the abstract devices, including group-related and stage-related management.
58  public:
DeviceManager()59   DeviceManager() { gm_ = GroupManager(); }
60   ~DeviceManager() = default;
61 
62   Status Init(const RankList &devices, int64_t global_device_rank, const RankList &stage_map,
63               const std::string &backend);
64 
65   static DeviceManager &GetInstance();
66   RankList GetDeviceListByStageId(int64_t stage_id) const;
67   RankList GetDeviceListInThisStage() const;
68   RankList GetDeviceListBetweenStage() const;
69 
70   Device CreateNewDeviceByRank(int64_t rank) const;
71   std::vector<Device> CreateDeviceListByRankList(RankList ranks) const;
72   std::string GenerateGroupNameByRanks(RankList ranks);
73   Status CreateGroup(const std::string &group_name, const std::vector<Device> &devices, Group *const comm_group);
74   Status CreateGroup(const RankList &dev_ranks, Group *const comm_group);
75 
DeviceNum()76   size_t DeviceNum() const { return devices_.size(); }
stage_num()77   int64_t stage_num() const { return stage_num_; }
stage_device_num()78   int64_t stage_device_num() const { return stage_device_num_; }
stage_id()79   int64_t stage_id() const { return stage_id_; }
rank_index_in_stage()80   int64_t rank_index_in_stage() const { return rank_index_in_stage_; }
global_rank()81   int64_t global_rank() const { return global_rank_; }
backend()82   std::string backend() const { return backend_; }
group_manager()83   GroupManager group_manager() const { return gm_; }
set_group_manager(const GroupManager & gm)84   void set_group_manager(const GroupManager &gm) { gm_ = gm; }
85 
stage_devices()86   std::vector<std::vector<int64_t>> stage_devices() const { return stage_devices_; }
87 
88   void Clear();
world_group()89   std::string world_group() const { return gm_.world_group(); }
group_info()90   std::vector<std::pair<std::string, std::vector<uint32_t>>> group_info() const { return gm_.group_info(); }
91   std::string FindRankListNameByHashName(const std::string &hash_name);
92   RankList FindRankListByHashName(const std::string &hash_name);
93   Status CheckDeviceList(const RankList &rank_list) const;
94 
95  private:
96   std::vector<std::shared_ptr<Device>> devices_;
97   // each stage has a list of devices
98   std::vector<std::vector<int64_t>> stage_devices_;
99   std::shared_ptr<Device> device_;
100   GroupManager gm_;
101   std::string backend_;
102 
103   // bimap:
104   std::map<std::string, std::string> rank_to_group_;  // the key is rank list, value is hash name
105   std::map<std::string, std::string> group_to_rank_;  // the key is hash name, value is rank list
106 
107   int64_t global_rank_ = 0;          // the real rank in all devices
108   int64_t stage_num_ = 1;            // the stage num
109   int64_t stage_id_ = 0;             // the stage id of the global_rank_
110   int64_t rank_index_in_stage_ = 0;  // the index of this rank in it's stage
111   int64_t stage_device_num_ = 0;     // the device num of one stage
112 };
113 }  // namespace parallel
114 }  // namespace mindspore
115 
116 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_DEVICE_MANAGER_H_
117