#include <c10/util/irange.h>
#include <fmt/format.h>
#include <fmt/ranges.h>
#include <torch/csrc/distributed/c10d/TCPStore.hpp>
#include <torch/csrc/distributed/c10d/TCPStoreBackend.hpp>
#include <torch/csrc/distributed/c10d/logging.h>

#include <fcntl.h>
#include <algorithm>
#include <array>
#include <fstream>
#include <mutex>
#include <random>
#include <thread>
#include <unordered_map>
#include <utility>

#ifdef _WIN32
#include <io.h>
#include <winsock2.h>
#else
#include <poll.h>
#include <unistd.h>
#endif

#ifdef _WIN32
#include <torch/csrc/distributed/c10d/WinSockUtils.hpp>
#else
#include <torch/csrc/distributed/c10d/UnixSockUtils.hpp>
#endif

#include <torch/csrc/distributed/c10d/socket.h>

namespace c10d {
namespace detail {

// Manages the lifecycle of a server daemon.
class TCPServer {
 public:
  static std::shared_ptr<TCPServer> start(const TCPStoreOptions& opts);

  std::uint16_t port() const noexcept {
    return port_;
  }

  explicit TCPServer(
      std::uint16_t port,
      std::unique_ptr<BackgroundThread>&& daemon)
      : port_{port}, daemon_{std::move(daemon)} {}

  std::string repr() const {
    return fmt::format("TCPServer(port={})", port_);
  }

 private:
  std::uint16_t port_;
  std::unique_ptr<BackgroundThread> daemon_;

  // We store weak references to all TCPServers for which the caller requested
  // multi-tenancy.
  static std::unordered_map<std::uint16_t, std::weak_ptr<TCPServer>>
      cachedServers_;

  static std::mutex cache_mutex_;
};

std::unordered_map<std::uint16_t, std::weak_ptr<TCPServer>>
    TCPServer::cachedServers_{};

std::mutex TCPServer::cache_mutex_{};

std::shared_ptr<TCPServer> TCPServer::start(const TCPStoreOptions& opts) {
  auto startCore = [&opts]() {
    auto daemon = opts.useLibUV ? create_libuv_tcpstore_backend(opts)
                                : create_tcpstore_backend(opts);
    daemon->start();
    return std::make_shared<TCPServer>(daemon->port(), std::move(daemon));
  };

  std::shared_ptr<TCPServer> server{};

  if (opts.multiTenant) {
    std::lock_guard<std::mutex> guard{cache_mutex_};

    // If the caller is okay with a multi-tenant store, first check if we
    // already have a TCPServer running on the specified port.
    if (opts.port > 0) {
      auto pos = cachedServers_.find(opts.port);
      if (pos != cachedServers_.end()) {
        server = pos->second.lock();
        if (server != nullptr) {
          return server;
        }

        // Looks like the TCPStore has been disposed; make sure that we release
        // the control block.
        cachedServers_.erase(pos);
      }
    }

    server = startCore();

    cachedServers_.emplace(server->port(), server);
  } else {
    server = startCore();
  }

  return server;
}
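// Illustrative sketch (not part of the implementation): with multiTenant set,
// two TCPStore instances created in the same process against the same port
// share one cached TCPServer through cachedServers_ above. Option values and
// the port are hypothetical.
//
//   c10d::TCPStoreOptions opts{};
//   opts.port = 29500;        // assumed to be free
//   opts.isServer = true;
//   opts.multiTenant = true;
//   c10d::TCPStore storeA{"localhost", opts};
//   c10d::TCPStore storeB{"localhost", opts}; // reuses the cached TCPServer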
class TCPClient {
 public:
  static std::unique_ptr<TCPClient> connect(
      const SocketAddress& addr,
      const TCPStoreOptions& opts,
      std::shared_ptr<Backoff> backoff);

  void sendRaw(uint8_t* data, size_t length) {
    try {
      tcputil::sendBytes(socket_.handle(), data, length);
    } catch (const std::exception& e) {
      C10D_WARNING("sendBytes failed on {}: {}", socket_.repr(), e.what());
      throw;
    }
  }

  std::vector<std::uint8_t> receiveBits() {
    try {
      return tcputil::recvVector<std::uint8_t>(socket_.handle());
    } catch (const std::exception& e) {
      C10D_WARNING("recvVector failed on {}: {}", socket_.repr(), e.what());
      throw;
    }
  }

  template <typename T>
  T receiveValue() {
    try {
      return tcputil::recvValue<T>(socket_.handle());
    } catch (const std::exception& e) {
      C10D_WARNING("recvValue failed on {}: {}", socket_.repr(), e.what());
      throw;
    }
  }

  template <typename T>
  bool receiveValueWithTimeout(T& t, std::chrono::milliseconds timeout) {
    if (!socket_.waitForInput(timeout))
      return false;
    t = tcputil::recvValue<T>(socket_.handle());
    return true;
  }

  void setTimeout(std::chrono::milliseconds value);

  explicit TCPClient(Socket&& socket) : socket_{std::move(socket)} {}

  std::string repr() const {
    return fmt::format("TCPClient({})", socket_.repr());
  }

 private:
  Socket socket_;
};

std::unique_ptr<TCPClient> TCPClient::connect(
    const SocketAddress& addr,
    const TCPStoreOptions& opts,
    std::shared_ptr<Backoff> backoff) {
  Socket socket = Socket::connect(
      addr.host,
      addr.port,
      SocketOptions{}
          .connect_timeout(opts.timeout)
          .connect_backoff(std::move(backoff)));

  return std::make_unique<TCPClient>(std::move(socket));
}

void TCPClient::setTimeout(std::chrono::milliseconds value) {
  if (value == std::chrono::milliseconds::zero()) {
    return;
  }

#ifdef _WIN32
  struct timeval timeoutTV = {
      static_cast<long>(value.count() / 1000),
      static_cast<long>((value.count() % 1000) * 1000)};
#else
  struct timeval timeoutTV = {
      .tv_sec = value.count() / 1000,
      .tv_usec = static_cast<suseconds_t>((value.count() % 1000) * 1000),
  };
#endif
  SYSCHECK_ERR_RETURN_NEG1(::setsockopt(
      socket_.handle(),
      SOL_SOCKET,
      SO_RCVTIMEO,
      reinterpret_cast<char*>(&timeoutTV),
      sizeof(timeoutTV)));
}

class SendBuffer {
  // ethernet mtu 1500 - 40 (ipv6 header) - 20 (tcp header)
  const size_t FLUSH_WATERMARK = 1440;
  std::vector<uint8_t> buffer;
  detail::TCPClient& client;

  void maybeFlush() {
    if (buffer.size() >= FLUSH_WATERMARK) {
      flush();
    }
  }

 public:
  SendBuffer(detail::TCPClient& client, detail::QueryType cmd)
      : client(client) {
    buffer.reserve(32); // enough for most commands
    buffer.push_back((uint8_t)cmd);
  }

  void appendString(const std::string& str) {
    appendValue<uint64_t>(str.size());
    buffer.insert(buffer.end(), str.begin(), str.end());
    maybeFlush();
  }

  void appendBytes(const std::vector<uint8_t>& vec) {
    appendValue<uint64_t>(vec.size());
    buffer.insert(buffer.end(), vec.begin(), vec.end());
    maybeFlush();
  }

  template <typename T>
  void appendValue(T value) {
    uint8_t* begin = (uint8_t*)&value;
    buffer.insert(buffer.end(), begin, begin + sizeof(T));
    maybeFlush();
  }

  void flush() {
    if (!buffer.empty()) {
      client.sendRaw(buffer.data(), buffer.size());
      buffer.clear();
    }
  }
};
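// Illustrative sketch of the byte stream SendBuffer above assembles for a SET
// request, assuming the uint64_t length prefixes written by appendString and
// appendBytes; this is a reading of the code, not a protocol specification:
//
//   [uint8  QueryType::SET]
//   [uint64 key length][key bytes]
//   [uint64 value length][value bytes]
//
// Payloads past FLUSH_WATERMARK (1440 bytes) are split across several writes
// by maybeFlush(); the receiver still sees one contiguous stream.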
} // namespace detail

using detail::Socket;

// TCPStore class methods
TCPStore::TCPStore(
    const std::string& masterAddr,
    std::uint16_t masterPort,
    std::optional<int> numWorkers,
    bool isServer,
    const std::chrono::milliseconds& timeout,
    bool waitWorkers)
    : TCPStore{
          masterAddr,
          TCPStoreOptions{
              masterPort,
              isServer,
              numWorkers ? std::optional<std::size_t>(*numWorkers)
                         : std::nullopt,
              waitWorkers,
              timeout}} {}

TCPStore::TCPStore(std::string host, const TCPStoreOptions& opts)
    : Store{opts.timeout},
      addr_{std::move(host)},
      numWorkers_{opts.numWorkers},
      usingLibUv_{opts.useLibUV} {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__init);
  if (opts.useLibUV) {
    TORCH_CHECK(
        ::c10d::detail::is_libuv_tcpstore_backend_available(),
        "use_libuv was requested but PyTorch was built without libuv support");

    if (opts.masterListenFd.has_value()) {
      // TODO(xilunwu): support this init method after testing
      constexpr auto* msg =
          "The libuv TCPStore backend does not support initialization with a listen fd. "
          "Please switch to the legacy TCPStore by setting environment variable USE_LIBUV "
          "to \"0\".";
      C10D_ERROR(msg);
      C10_THROW_ERROR(NotImplementedError, msg);
      return;
    }
  }

  Socket::initialize();

  if (opts.isServer) {
    server_ = detail::TCPServer::start(opts);
    // server successfully started
    C10D_DEBUG("The server has started on port = {}.", server_->port());

    std::ifstream maxconnFile("/proc/sys/net/core/somaxconn");
    if (maxconnFile.good() && numWorkers_.has_value()) {
      try {
        std::string str(
            (std::istreambuf_iterator<char>(maxconnFile)),
            std::istreambuf_iterator<char>());
        std::size_t somaxconn = std::stoll(str);
        if (somaxconn < *numWorkers_) {
          C10D_WARNING(
              "Starting store with {} workers but somaxconn is {}. "
              "This might cause instability during bootstrap, consider increasing it.",
              *numWorkers_,
              somaxconn);
        }
      } catch (std::logic_error& e) {
        C10D_INFO("failed to parse somaxconn proc file due to {}", e.what());
      }
    }

    addr_.port = server_->port();
  } else {
    addr_.port = opts.port;
  }

  // Try connecting several times -- if the server listen backlog is full it
  // may fail on the first send in validate.
  auto deadline = std::chrono::steady_clock::now() + opts.timeout;
  auto backoff = std::make_shared<ExponentialBackoffWithJitter>();

  auto retry = 0;
  do {
    try {
      client_ = detail::TCPClient::connect(addr_, opts, backoff);
      // TCP connection established
      C10D_DEBUG("TCP client connected to host {}:{}", addr_.host, addr_.port);

      // client's first query for validation
      validate();

      // ping to verify network connectivity
      ping();

      // success
      break;
    } catch (const c10::DistNetworkError& ex) {
      if (deadline < std::chrono::steady_clock::now()) {
        C10D_ERROR(
            "TCP client failed to connect/validate to host {}:{} - timed out (try={}, timeout={}ms): {}",
            addr_.host,
            addr_.port,
            retry,
            opts.timeout.count(),
            ex.what());
        throw;
      }

      auto delayDuration = backoff->nextBackoff();

      C10D_WARNING(
          "TCP client failed to connect/validate to host {}:{} - retrying (try={}, timeout={}ms, delay={}ms): {}",
          addr_.host,
          addr_.port,
          retry,
          opts.timeout.count(),
          delayDuration.count(),
          ex.what());

      std::this_thread::sleep_for(delayDuration);
      retry += 1;
    }
  } while (true);

  if (opts.waitWorkers) {
    waitForWorkers();
  }
}
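// Illustrative usage sketch (hostnames, ports, and variables are
// hypothetical): one rank hosts the server, the others connect as clients;
// when waitWorkers is true, construction blocks in waitForWorkers() until all
// numWorkers ranks have checked in.
//
//   c10d::TCPStoreOptions opts{};
//   opts.port = 29500;            // assumed to be free
//   opts.numWorkers = worldSize;  // e.g. 4
//   opts.isServer = (rank == 0);
//   opts.waitWorkers = true;
//   c10d::TCPStore store{"master.example.com", opts};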
TCPStore::~TCPStore() = default;

void TCPStore::waitForWorkers() {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__waitForWorkers);
  if (numWorkers_ == std::nullopt) {
    return;
  }

  incrementValueBy(initKey_, 1);

  // Let the server block until all workers have completed; this ensures that
  // the server daemon thread is always running until the very end
  if (server_) {
    const auto start = std::chrono::steady_clock::now();
    while (true) {
      // TODO: Any chance to make this cleaner?
      std::vector<uint8_t> value = doGet(initKey_);
      auto buf = reinterpret_cast<const char*>(value.data());
      auto len = value.size();
      int numWorkersCompleted = std::stoi(std::string(buf, len));
      if (numWorkersCompleted >= static_cast<int>(*numWorkers_)) {
        break;
      }
      const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
          std::chrono::steady_clock::now() - start);
      if (timeout_ != kNoTimeout && elapsed > timeout_) {
        C10_THROW_ERROR(
            DistStoreError,
            fmt::format(
                "Timed out after {} seconds waiting for clients. "
                "{}/{} clients joined.",
                elapsed.count(),
                numWorkersCompleted,
                *numWorkers_));
      }
      /* sleep override */
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
  }
}

void TCPStore::validate() {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::VALIDATE);
  buffer.appendValue(c10d::detail::validationMagicNumber);
  buffer.flush();
}

void TCPStore::ping() {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::PING);

  uint32_t nonce = getpid();
  buffer.appendValue(nonce);
  buffer.flush();

  uint32_t returnedNonce = client_->receiveValue<std::uint32_t>();
  TORCH_INTERNAL_ASSERT(
      nonce == returnedNonce, "Ping failed, invalid nonce returned");
}

void TCPStore::_splitSet(
    const std::string& key,
    const std::vector<uint8_t>& data) {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::SET);
  buffer.appendString(keyPrefix_ + key);
  buffer.flush();

  std::this_thread::sleep_for(std::chrono::milliseconds(1000));

  buffer.appendBytes(data);
  buffer.flush();
}

void TCPStore::set(const std::string& key, const std::vector<uint8_t>& data) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__set);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::SET);
  buffer.appendString(keyPrefix_ + key);
  buffer.appendBytes(data);
  buffer.flush();
}

std::vector<uint8_t> TCPStore::compareSet(
    const std::string& key,
    const std::vector<uint8_t>& expectedValue,
    const std::vector<uint8_t>& desiredValue) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__compareSet);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::COMPARE_SET);
  buffer.appendString(keyPrefix_ + key);
  buffer.appendBytes(expectedValue);
  buffer.appendBytes(desiredValue);
  buffer.flush();

  return client_->receiveBits();
}

std::vector<uint8_t> TCPStore::get(const std::string& key) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__get);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  return doGet(keyPrefix_ + key);
}

std::vector<uint8_t> TCPStore::doGet(const std::string& key) {
  doWait(key, timeout_);
  detail::SendBuffer buffer(*client_, detail::QueryType::GET);
  buffer.appendString(key);
  buffer.flush();

  return client_->receiveBits();
}

int64_t TCPStore::add(const std::string& key, int64_t value) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__add);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  return incrementValueBy(keyPrefix_ + key, value);
}

bool TCPStore::deleteKey(const std::string& key) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__delete);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::DELETE_KEY);
  buffer.appendString(keyPrefix_ + key);
  buffer.flush();

  auto numDeleted = client_->receiveValue<std::int64_t>();
  return numDeleted == 1;
}

int64_t TCPStore::incrementValueBy(const std::string& key, int64_t delta) {
  detail::SendBuffer buff(*client_, detail::QueryType::ADD);
  buff.appendString(key);
  buff.appendValue(delta);
  buff.flush();

  return client_->receiveValue<std::int64_t>();
}

int64_t TCPStore::getNumKeys() {
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::GETNUMKEYS);
  buffer.flush();

  return client_->receiveValue<std::int64_t>();
}

bool TCPStore::check(const std::vector<std::string>& keys) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__check);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::CHECK);
  buffer.appendValue(keys.size());

  for (const std::string& key : keys) {
    buffer.appendString(keyPrefix_ + key);
  }
  buffer.flush();

  auto response = client_->receiveValue<detail::CheckResponseType>();
  if (response == detail::CheckResponseType::READY) {
    return true;
  }
  if (response == detail::CheckResponseType::NOT_READY) {
    return false;
  }
  TORCH_CHECK(false, "ready or not_ready response expected");
}
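// Illustrative sketch (keys and byte vectors are hypothetical): check() is a
// non-blocking probe, while wait() below blocks until the keys exist or the
// given timeout expires.
//
//   store.set("arrived/0", someBytes);
//   if (!store.check({"arrived/0", "arrived/1"})) {
//     store.wait({"arrived/0", "arrived/1"}, std::chrono::seconds(30));
//   }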
void TCPStore::wait(const std::vector<std::string>& keys) {
  wait(keys, timeout_);
}

void TCPStore::wait(
    const std::vector<std::string>& keys,
    const std::chrono::milliseconds& timeout) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__wait);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  std::vector<std::string> prefixedKeys{};
  prefixedKeys.reserve(keys.size());
  for (const std::string& key : keys) {
    prefixedKeys.emplace_back(keyPrefix_ + key);
  }

  doWait(prefixedKeys, timeout);
}

void TCPStore::doWait(
    c10::ArrayRef<std::string> keys,
    std::chrono::milliseconds timeout) {
  {
    detail::SendBuffer buffer(*client_, detail::QueryType::WAIT);
    buffer.appendValue(keys.size());
    for (const std::string& key : keys) {
      buffer.appendString(key);
    }
    buffer.flush();
  }

  detail::WaitResponseType response;
  if (client_->receiveValueWithTimeout(response, timeout)) {
    if (response != detail::WaitResponseType::STOP_WAITING) {
      TORCH_CHECK(false, "Stop_waiting response is expected");
    }
    return;
  }

  // this is the cancel wait timeout; once here we expect the server to respond
  // in a timely fashion
  {
    detail::SendBuffer buffer(*client_, detail::QueryType::CANCEL_WAIT);
    buffer.flush();
  }

  response = client_->receiveValue<detail::WaitResponseType>();
  // this can happen if the server responds before we cancel, just ignore it
  if (response != detail::WaitResponseType::WAIT_CANCELED) {
    if (response != detail::WaitResponseType::STOP_WAITING) {
      TORCH_CHECK(false, "Stop_waiting response is expected");
    }

    response = client_->receiveValue<detail::WaitResponseType>(); // ignore
    if (response != detail::WaitResponseType::WAIT_CANCELED) {
      TORCH_CHECK(false, "wait_canceled response is expected");
    }
  }
  C10_THROW_ERROR(
      DistStoreError,
      fmt::format(
          "wait timeout after {}ms, keys: {}",
          timeout.count(),
          fmt::join(keys, ", ")));
}

void TCPStore::append(
    const std::string& key,
    const std::vector<uint8_t>& data) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__append);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  detail::SendBuffer buffer(*client_, detail::QueryType::APPEND);
  buffer.appendString(keyPrefix_ + key);
  buffer.appendBytes(data);
  buffer.flush();
}

std::vector<std::vector<uint8_t>> TCPStore::multiGet(
    const std::vector<std::string>& keys) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__multiGet);
  const std::lock_guard<std::mutex> lock(activeOpLock_);
  std::vector<std::string> prefixedKeys;
  prefixedKeys.reserve(keys.size());
  for (const std::string& key : keys) {
    prefixedKeys.emplace_back(keyPrefix_ + key);
  }
  doWait(prefixedKeys, timeout_);

  detail::SendBuffer buffer(*client_, detail::QueryType::MULTI_GET);
  buffer.appendValue(keys.size());
  for (auto& key : prefixedKeys) {
    buffer.appendString(key);
  }
  buffer.flush();

  std::vector<std::vector<uint8_t>> result;
  result.reserve(keys.size());
  for (size_t i = 0; i < keys.size(); ++i) {
    result.emplace_back(client_->receiveBits());
  }
  return result;
}

void TCPStore::multiSet(
    const std::vector<std::string>& keys,
    const std::vector<std::vector<uint8_t>>& values) {
  STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.TCPStore__multiSet);
  TORCH_CHECK(
      keys.size() == values.size(),
      "multiSet keys and values vectors must be of same size");
  const std::lock_guard<std::mutex> lock(activeOpLock_);

  detail::SendBuffer buffer(*client_, detail::QueryType::MULTI_SET);
  buffer.appendValue(keys.size());
  for (auto i : c10::irange(keys.size())) {
    buffer.appendString(keyPrefix_ + keys[i]);
    buffer.appendBytes(values[i]);
  }
  buffer.flush();
}
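// Illustrative sketch (keys and byte vectors are hypothetical): append,
// multiGet, and multiSet form the batched "extended API" advertised by
// hasExtendedApi() below.
//
//   store.multiSet({"k0", "k1"}, {bytes0, bytes1});
//   auto values = store.multiGet({"k0", "k1"}); // two std::vector<uint8_t>
//   store.append("k0", moreBytes);              // server-side concatenation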
bool TCPStore::hasExtendedApi() const {
  return true;
}

std::string TCPStore::repr() const {
  auto clientRepr = client_ ? client_->repr() : "<nullptr>";
  auto serverRepr = server_ ? server_->repr() : "<nullptr>";
  return fmt::format("TCPStore(client={}, server={})", clientRepr, serverRepr);
}

} // namespace c10d