#include namespace { std::string get_detector_key( c10::DeviceType device_type, std::string connection_type) { std::ostringstream oss; oss << device_type << "/" << connection_type; return oss.str(); } class DetectorMap { public: static DetectorMap& get() { static DetectorMap instance; return instance; } void register_detector( c10::DeviceType device_type, const std::string& connection_type, c10::intrusive_ptr detector) { auto key = get_detector_key(device_type, connection_type); detector_map_[key] = std::move(detector); } c10::intrusive_ptr detect( c10::DeviceType device_type, const std::string& connection_type) { auto key = get_detector_key(device_type, connection_type); { auto it = cached_.find(key); if (it != cached_.end()) { return it->second; } } auto it = detector_map_.find(key); TORCH_CHECK( it != detector_map_.end(), "DMA connectivity detector for ", device_type, " over ", connection_type, " is not available"); auto detector = it->second; auto connectivity = detector->detect(); cached_[key] = connectivity; return connectivity; } private: DetectorMap() = default; DetectorMap(const DetectorMap&) = delete; DetectorMap& operator=(const DetectorMap&) = delete; std::unordered_map< std::string, c10::intrusive_ptr> detector_map_; std::unordered_map> cached_; }; }; // namespace namespace c10d { DMAConnectivity::DMAConnectivity( c10::DeviceType device_type, std::string connection_type, std::vector> matrix) : device_type(device_type), connection_type(connection_type), matrix(std::move(matrix)) {} void register_dma_connectivity_detector( c10::DeviceType device_type, const std::string& connection_type, c10::intrusive_ptr detector) { return DetectorMap::get().register_detector( device_type, connection_type, std::move(detector)); } c10::intrusive_ptr detect_dma_connectivity( c10::DeviceType device_type, const std::string& connection_type) { return DetectorMap::get().detect(device_type, connection_type); } } // namespace c10d