/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_MS_COLLECTIVE_COMM_LIB_H_
#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_MS_COLLECTIVE_COMM_LIB_H_

#include <memory>
#include <vector>
#include <string>
#include <random>
#include "runtime/collective/collective_communication_lib.h"
#include "plugin/device/cpu/hal/hardware/ms_communication_group.h"
#include "include/backend/distributed/cluster/cluster_context.h"
#include "ps/core/collective_ops_impl.h"
#include "plugin/device/cpu/hal/hardware/ms_collective_node.h"
#include "plugin/device/cpu/hal/hardware/allreduce_impl.h"
#include "include/backend/distributed/cluster/topology/compute_graph_node.h"

namespace mindspore {
namespace device {
namespace cpu {
constexpr char kMCCLGlobalGroupName[] = "mccl_world_group";
using ClusterContext = mindspore::distributed::cluster::ClusterContext;
using CollectiveOpsImpl = mindspore::fl::server::CollectiveOpsImpl;
using CommunicationGroupInfo = mindspore::fl::server::CommunicationGroupInfo;
using ps::core::NodeCommand;

// The time interval for the worker to send info to, or query info from, the scheduler.
constexpr uint32_t kWaitDuration = 5;

// The retry count for initializing MsCollectiveCommLib.
constexpr uint32_t kMSCollectiveRetryTime = 200;
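// A minimal sketch (an illustration, not the actual implementation) of how the
// constants above and the rand_distrib_ member declared below could combine into
// a randomized retry loop against the scheduler. It assumes rand_distrib_ is
// constructed with a small range, e.g. {1, kWaitDuration}:
//
//   std::mt19937 gen(std::random_device{}());
//   for (uint32_t retry = 0; retry < kMSCollectiveRetryTime; ++retry) {
//     if (QueryUniqueID(group_name, root_info_size, root_info)) {
//       break;
//     }
//     // Sleep for a randomized interval so workers do not poll in lockstep.
//     std::this_thread::sleep_for(std::chrono::seconds(rand_distrib_(gen)));
//   }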
// The collective communication library for MindSpore's self-developed communication framework.
class MsCollectiveCommLib : public CollectiveCommunicationLib {
 public:
  static MsCollectiveCommLib &GetInstance() {
    static MsCollectiveCommLib instance;
    return instance;
  }

  bool Initialize(uint32_t global_rank, uint32_t global_rank_size, uint32_t local_rank_id) override;

  bool Finalize() override;

  bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks,
                                uint32_t local_group_rank, uint32_t local_group_size) override;

  bool AllGatherHostHashName(size_t host_hash_name, std::vector<size_t> *host_hash_names) override;

  bool BroadcastUniqueID(const std::string &group_name, size_t root_info_size, void *root_info) override;

  bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
                 const std::string &group_name, void *stream = nullptr) override;

  bool AllReduce(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
                 CollectiveOpReduceType reduce_op, const std::string &group_name, void *stream = nullptr) override;

  bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type, uint32_t root_rank,
                 const std::string &group_name, void *stream = nullptr) override;

  // ReduceScatter is a no-op in this library; it simply returns success.
  bool ReduceScatter(const void *send_buff, void *recv_buff, size_t recv_count, TypeId data_type,
                     CollectiveOpReduceType reduce_op, const std::string &group_name,
                     void *stream = nullptr) override {
    return true;
  }

 private:
  MsCollectiveCommLib();
  ~MsCollectiveCommLib() override = default;

  // Send the unique id to the scheduler.
  bool SendUniqueID(const std::string &group_name, size_t root_info_size, const void *root_info) const;

  // Query the unique id from the scheduler.
  bool QueryUniqueID(const std::string &group_name, size_t root_info_size, void *root_info);

  // The collective node that carries out the underlying communication.
  std::shared_ptr<ps::core::CollectiveNode> node_;

  // This compute graph node is maintained by the cluster context and used for metadata synchronization.
  std::shared_ptr<distributed::cluster::topology::ComputeGraphNode> cgn_;

  // The launcher that executes the AllReduce operation.
  std::unique_ptr<AllReduceLauncher> launcher_;

  // Indicates whether the collective node has to synchronize the addresses of all the collective nodes.
  bool synchronized_{true};

  // The number of retries attempted so far.
  uint64_t retry_count_;

  // Random retry interval.
  std::uniform_int_distribution<> rand_distrib_;
};
}  // namespace cpu
}  // namespace device
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_MS_COLLECTIVE_COMM_LIB_H_
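
// ---------------------------------------------------------------------------
// Illustrative usage (a minimal sketch, not part of the original header): how
// a caller might drive the public interface declared above. The ranks, buffer
// sizes, and the reduce-op enumerator name are assumptions for illustration.
//
//   auto &comm_lib = mindspore::device::cpu::MsCollectiveCommLib::GetInstance();
//   if (!comm_lib.Initialize(/*global_rank=*/0, /*global_rank_size=*/8, /*local_rank_id=*/0)) {
//     // Handle initialization failure.
//   }
//   std::vector<float> send(16, 1.0f);
//   std::vector<float> recv(16, 0.0f);
//   // Sum-reduce 16 floats across all ranks in the default world group.
//   (void)comm_lib.AllReduce(send.data(), recv.data(), send.size(), kNumberTypeFloat32,
//                            CollectiveOpReduceType::Reduce_Sum, kMCCLGlobalGroupName);
//   (void)comm_lib.Finalize();
// ---------------------------------------------------------------------------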