1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_DEVICE_MANAGER_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_DEVICE_MANAGER_H_ 19 20 #include <cstdint> 21 #include <cstring> 22 #include <map> 23 #include <memory> 24 #include <string> 25 #include <utility> 26 #include <vector> 27 #include "frontend/parallel/device.h" 28 #include "frontend/parallel/device_matrix.h" 29 #include "frontend/parallel/group_manager.h" 30 #include "frontend/parallel/status.h" 31 #include "frontend/parallel/strategy.h" 32 #include "include/common/utils/convert_utils.h" 33 #include "utils/ms_utils.h" 34 35 namespace mindspore { 36 namespace parallel { 37 constexpr int64_t MAX_DEVICE_NUM = 4294967295; 38 constexpr size_t DEVICE_NUM_PER_SERVER = 8; 39 constexpr char HCCL_BACKEND[] = "hccl"; 40 constexpr char NCCL_BACKEND[] = "nccl"; 41 constexpr char UNDEFINED_BACKEND[] = "undefined_backend"; 42 43 class DeviceManager; 44 using DeviceManagerPtr = std::shared_ptr<DeviceManager>; 45 // 'g_device_manager' is the globally unique manager to manage the devices. 46 extern DeviceManagerPtr g_device_manager; 47 48 // This method is used for initializing the global DeviceManager 'g_device_manager', 49 // arguments including 'device_num' and 'global_rank' 50 bool InitDevice(int64_t device_num, int64_t global_rank, const std::string &backend, const std::vector<int64_t> &stage); 51 52 void CheckGlobalDeviceManager(); 53 54 std::string HashName(const std::string &origin_name); 55 56 class DeviceManager { 57 // This class is used to manage the abstract devices, including group-related and stage-related management. 58 public: DeviceManager()59 DeviceManager() { gm_ = GroupManager(); } 60 ~DeviceManager() = default; 61 62 Status Init(const RankList &devices, int64_t global_device_rank, const RankList &stage_map, 63 const std::string &backend); 64 65 static DeviceManager &GetInstance(); 66 RankList GetDeviceListByStageId(int64_t stage_id) const; 67 RankList GetDeviceListInThisStage() const; 68 RankList GetDeviceListBetweenStage() const; 69 70 Device CreateNewDeviceByRank(int64_t rank) const; 71 std::vector<Device> CreateDeviceListByRankList(RankList ranks) const; 72 std::string GenerateGroupNameByRanks(RankList ranks); 73 Status CreateGroup(const std::string &group_name, const std::vector<Device> &devices, Group *const comm_group); 74 Status CreateGroup(const RankList &dev_ranks, Group *const comm_group); 75 DeviceNum()76 size_t DeviceNum() const { return devices_.size(); } stage_num()77 int64_t stage_num() const { return stage_num_; } stage_device_num()78 int64_t stage_device_num() const { return stage_device_num_; } stage_id()79 int64_t stage_id() const { return stage_id_; } rank_index_in_stage()80 int64_t rank_index_in_stage() const { return rank_index_in_stage_; } global_rank()81 int64_t global_rank() const { return global_rank_; } backend()82 std::string backend() const { return backend_; } group_manager()83 GroupManager group_manager() const { return gm_; } set_group_manager(const GroupManager & gm)84 void set_group_manager(const GroupManager &gm) { gm_ = gm; } 85 stage_devices()86 std::vector<std::vector<int64_t>> stage_devices() const { return stage_devices_; } 87 88 void Clear(); world_group()89 std::string world_group() const { return gm_.world_group(); } group_info()90 std::vector<std::pair<std::string, std::vector<uint32_t>>> group_info() const { return gm_.group_info(); } 91 std::string FindRankListNameByHashName(const std::string &hash_name); 92 RankList FindRankListByHashName(const std::string &hash_name); 93 Status CheckDeviceList(const RankList &rank_list) const; 94 95 private: 96 std::vector<std::shared_ptr<Device>> devices_; 97 // each stage has a list of devices 98 std::vector<std::vector<int64_t>> stage_devices_; 99 std::shared_ptr<Device> device_; 100 GroupManager gm_; 101 std::string backend_; 102 103 // bimap: 104 std::map<std::string, std::string> rank_to_group_; // the key is rank list, value is hash name 105 std::map<std::string, std::string> group_to_rank_; // the key is hash name, value is rank list 106 107 int64_t global_rank_ = 0; // the real rank in all devices 108 int64_t stage_num_ = 1; // the stage num 109 int64_t stage_id_ = 0; // the stage id of the global_rank_ 110 int64_t rank_index_in_stage_ = 0; // the index of this rank in it's stage 111 int64_t stage_device_num_ = 0; // the device num of one stage 112 }; 113 } // namespace parallel 114 } // namespace mindspore 115 116 #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_DEVICE_MANAGER_H_ 117