/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_MS_COLLECTIVE_COMM_LIB_H_
#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_MS_COLLECTIVE_COMM_LIB_H_

#include <memory>
#include <vector>
#include <string>
#include <random>
#include "runtime/collective/collective_communication_lib.h"
#include "plugin/device/cpu/hal/hardware/ms_communication_group.h"
#include "include/backend/distributed/cluster/cluster_context.h"
#include "ps/core/collective_ops_impl.h"
#include "plugin/device/cpu/hal/hardware/ms_collective_node.h"
#include "plugin/device/cpu/hal/hardware/allreduce_impl.h"
#include "include/backend/distributed/cluster/topology/compute_graph_node.h"

namespace mindspore {
namespace device {
namespace cpu {
constexpr char kMCCLGlobalGroupName[] = "mccl_world_group";
using ClusterContext = mindspore::distributed::cluster::ClusterContext;
using CollectiveOpsImpl = mindspore::fl::server::CollectiveOpsImpl;
using CommunicationGroupInfo = mindspore::fl::server::CommunicationGroupInfo;
using ps::core::NodeCommand;

// The time interval at which the worker sends info to or queries info from the scheduler.
constexpr uint32_t kWaitDuration = 5;

// The number of retries for initializing MsCollectiveCommLib.
constexpr uint32_t kMSCollectiveRetryTime = 200;
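
// A minimal sketch (not part of the original source) of how these two
// constants are presumably combined into a poll-with-retry loop; the unit of
// kWaitDuration (seconds here) and the helper name are assumptions:
//
//   for (uint32_t retry = 0; retry < kMSCollectiveRetryTime; ++retry) {
//     if (TryQueryFromScheduler()) break;  // hypothetical helper
//     std::this_thread::sleep_for(std::chrono::seconds(kWaitDuration));
//   }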

// The collective communication library for MindSpore's self-developed communication framework.
class MsCollectiveCommLib : public CollectiveCommunicationLib {
 public:
  static MsCollectiveCommLib &GetInstance() {
    static MsCollectiveCommLib instance;
    return instance;
  }
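
  // Example call sequence (a hedged sketch; the real call sites live in the
  // runtime, and the rank/group arguments below are placeholders):
  //   auto &comm_lib = MsCollectiveCommLib::GetInstance();
  //   comm_lib.Initialize(global_rank, global_rank_size, local_rank_id);
  //   comm_lib.CreateCommunicationGroup(kMCCLGlobalGroupName, group_ranks, local_group_rank, local_group_size);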

  bool Initialize(uint32_t global_rank, uint32_t global_rank_size, uint32_t local_rank_id) override;

  bool Finalize() override;

  bool CreateCommunicationGroup(const std::string &group_name, const std::vector<uint32_t> &group_ranks,
                                uint32_t local_group_rank, uint32_t local_group_size) override;

  bool AllGatherHostHashName(size_t host_hash_name, std::vector<size_t> *host_hash_names) override;

  bool BroadcastUniqueID(const std::string &group_name, size_t root_info_size, void *root_info) override;

  bool AllGather(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
                 const std::string &group_name, void *stream = nullptr) override;

  bool AllReduce(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type,
                 CollectiveOpReduceType reduce_op, const std::string &group_name, void *stream = nullptr) override;

  bool Broadcast(const void *send_buff, void *recv_buff, size_t send_count, TypeId data_type, uint32_t root_rank,
                 const std::string &group_name, void *stream = nullptr) override;

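  // Note: ReduceScatter is effectively unsupported here; the stub below
  // reports success without performing any data exchange.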
  bool ReduceScatter(const void *send_buff, void *recv_buff, size_t recv_count, TypeId data_type,
                     CollectiveOpReduceType reduce_op, const std::string &group_name, void *stream = nullptr) override {
    return true;
  }

 private:
  MsCollectiveCommLib();
  ~MsCollectiveCommLib() override = default;

  // Send unique id to scheduler.
  bool SendUniqueID(const std::string &group_name, size_t root_info_size, const void *root_info) const;

  // Query unique id from scheduler.
  bool QueryUniqueID(const std::string &group_name, size_t root_info_size, void *root_info);
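
  // A hedged sketch of the assumed exchange (the actual logic, including which
  // rank acts as root, lives in the implementation): the root rank pushes the
  // unique id to the scheduler while the other ranks poll until it is ready:
  //   if (is_root) (void)SendUniqueID(group_name, root_info_size, root_info);
  //   else while (!QueryUniqueID(group_name, root_info_size, root_info)) { /* wait kWaitDuration */ }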

  std::shared_ptr<ps::core::CollectiveNode> node_;

  // This compute graph node is maintained by the cluster context and used for metadata synchronization.
  std::shared_ptr<distributed::cluster::topology::ComputeGraphNode> cgn_;

  std::unique_ptr<AllReduceLauncher> launcher_;

  // Indicates whether the collective node has to synchronize the addresses of all the collective nodes.
  bool synchronized_{true};

  uint64_t retry_count_;

  // Random retry interval.
  std::uniform_int_distribution<> rand_distrib_;
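
  // Sketch of the assumed jitter usage (engine and seeding are illustrative,
  // not taken from this header):
  //   std::mt19937 gen(std::random_device{}());
  //   auto backoff = rand_distrib_(gen);  // randomized retry interval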
};
}  // namespace cpu
}  // namespace device
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_MS_COLLECTIVE_COMM_LIB_H_