#include #include #include #include #include #include #include #include #include #include #include #include #include #include #ifdef TORCH_USE_LIBUV #include #endif namespace c10d::detail { #ifdef TORCH_USE_LIBUV /* Exception safety: It's ok to use exceptions during client processing. Other callbacks don't provide exception safety so avoid there. */ // This controls how many un-accepted TCP connections can be waiting in the // backlog. This should be at least world size to avoid issues on init. We set // it to -1 to use the host max value which is controlled by `soconnmax`. #define DEFAULT_BACKLOG -1 #define MAX_KEY_COUNT (128 * 1024) #define MAX_STRING_LEN (8 * 1024) #define MAX_PAYLOAD_LEN (8 * 1024 * 1024) // This controls the preferred size for buffers. // Too small and we'll need multiple buffers for one request // Too big and we might taxing malloc #define ALLOC_BUFFER_SIZE ((size_t)4000) class UvHandle : public c10::intrusive_ptr_target { public: ~UvHandle() override = default; c10::intrusive_ptr iptr() { return c10::intrusive_ptr::reclaim_copy(this); } void close() { if (uv_is_closing(unsafeGetHandle())) { return; } uv_close(unsafeGetHandle(), on_close); } virtual uv_handle_t* unsafeGetHandle() = 0; protected: void handleReady() { /* This method must be called once the handle is ready and registered with the loop. Do not call this in the ctor, make_intrusive reset refcounts to one after construction. */ uv_handle_set_data(unsafeGetHandle(), this); at::raw::intrusive_ptr::incref(this); } virtual void onClose() = 0; private: static c10::intrusive_ptr reclaim(uv_handle_t* handle) { auto h = (UvHandle*)uv_handle_get_data(handle); return c10::intrusive_ptr::reclaim(h); } static void on_close(uv_handle_t* uv_handle) { auto handle = reclaim(uv_handle); handle->onClose(); } }; class UvTcpSocket : public UvHandle { uv_tcp_t client{}; c10::intrusive_ptr iptr() { return c10::intrusive_ptr::reclaim_copy(this); } static c10::intrusive_ptr borrow(uv_stream_t* handle) { auto h = (UvTcpSocket*)uv_handle_get_data((uv_handle_t*)handle); return h->iptr(); } static void alloc_buffer( uv_handle_t* handle, size_t suggested_size, uv_buf_t* buf) { suggested_size = std::min(suggested_size, (size_t)ALLOC_BUFFER_SIZE); buf->base = (char*)malloc(suggested_size); buf->len = suggested_size; } static void read_callback( uv_stream_t* client, ssize_t nread, const uv_buf_t* buf) { auto uv_socket = UvTcpSocket::borrow(client); if (nread < 0) { C10D_DEBUG( "Read callback failed. code:{} name:{} desc:{}", nread, uv_err_name(nread), uv_strerror(nread)); uv_socket->close(); return; } if (nread > 0) { try { uv_socket->processBuf(buf, nread); } catch (std::exception& ex) { C10D_WARNING("Error processing client message: {}", ex.what()); uv_socket->close(); } } } public: explicit UvTcpSocket(uv_loop_t* loop) { uv_tcp_init(loop, &client); if (int err = uv_tcp_nodelay(&client, 1)) { C10D_WARNING( "The no-delay option cannot be enabled for the client socket. err={}", err); } } void startRead() { int res = uv_read_start((uv_stream_t*)&client, alloc_buffer, read_callback); if (res) { C10D_WARNING( "Failed to setup read callback. client:{} code:{} name:{} desc:{}.", (void*)this, res, uv_err_name(res), uv_strerror(res)); close(); } } uv_handle_t* unsafeGetHandle() override { return (uv_handle_t*)&client; } protected: uv_stream_t* unsafeGetStream() { return (uv_stream_t*)&client; } uv_tcp_t* unsafeGetSocket() { return &client; } virtual void processBuf(const uv_buf_t* buf, size_t nread) { TORCH_CHECK( false, "Trying to read from a socket subclass that lacks processBuf"); } void onClose() override { // TODO use registerClient (and rename it to registerHandle) - this will // significantly simplify things. } }; class UvTcpServer : public UvTcpSocket { public: typedef std::function OnConnectCallback; explicit UvTcpServer(uv_loop_t* loop) : UvTcpSocket(loop), onConnectCb(missingOnConnect) {} static c10::intrusive_ptr makeWithSocket( uv_loop_t* loop, int socket) { auto res = c10::make_intrusive(loop); res->handleReady(); try { int uv_res = uv_tcp_open((uv_tcp_t*)res->unsafeGetStream(), socket); TORCH_CHECK( uv_res == 0, "Failed to open existing socket. ", "socket: ", socket, ", code: ", uv_res, ", name: ", uv_err_name(uv_res), ", message: ", uv_strerror(uv_res)); res->cacheSocketPort(); } catch (std::exception& ex) { res->close(); throw; } return res; } void setOnConnectCallback(OnConnectCallback&& callback) { onConnectCb = std::move(callback); } static c10::intrusive_ptr makeWithPort( uv_loop_t* loop, uint16_t port, bool useIpv6) { auto res = c10::make_intrusive(loop); res->handleReady(); try { struct sockaddr_storage addr {}; int uv_res = 0; if (useIpv6) { uv_res = uv_ip6_addr("::", port, (struct sockaddr_in6*)&addr); } else { uv_res = uv_ip4_addr("0.0.0.0", port, (struct sockaddr_in*)&addr); } TORCH_CHECK( uv_res == 0, "UV Store addr parsing failure. ", "port: ", port, ", useIpv6: ", useIpv6, ", code: ", uv_res, ", name: ", uv_err_name(uv_res), ", message: ", uv_strerror(uv_res)); uv_res = uv_tcp_bind(res->unsafeGetSocket(), (const struct sockaddr*)&addr, 0); TORCH_CHECK( uv_res == 0, "The server socket has failed to bind. ", "port: ", port, ", useIpv6: ", useIpv6, ", code: ", uv_res, ", name: ", uv_err_name(uv_res), ", message: ", uv_strerror(uv_res)); uv_res = uv_listen(res->unsafeGetStream(), DEFAULT_BACKLOG, on_new_connection); TORCH_CHECK( uv_res == 0, "The server socket has failed to listen on any local network address. ", "port: ", port, ", useIpv6: ", useIpv6, ", code: ", uv_res, ", name: ", uv_err_name(uv_res), ", message: ", uv_strerror(uv_res)); res->cacheSocketPort(); } catch (std::exception& ex) { res->close(); throw; } return res; } uint16_t port() const { return portNum; } void accept(const c10::intrusive_ptr& socket) { int res = uv_accept(unsafeGetStream(), (uv_stream_t*)socket->unsafeGetHandle()); TORCH_CHECK( res == 0, "Failed to accept socket. ", "code: ", res, ", name: ", uv_err_name(res), ", message: ", uv_strerror(res)); } private: OnConnectCallback onConnectCb; uint16_t portNum{}; c10::intrusive_ptr iptr() { return c10::intrusive_ptr::reclaim_copy(this); } static c10::intrusive_ptr borrow(uv_stream_t* handle) { auto h = (UvTcpServer*)uv_handle_get_data((uv_handle_t*)handle); return h->iptr(); } void cacheSocketPort() { sockaddr_storage addr_s{}; int addr_len = sizeof(addr_s); if (uv_tcp_getsockname( (uv_tcp_t*)unsafeGetStream(), reinterpret_cast(&addr_s), &addr_len) != 0) { throw std::runtime_error( "The port number of the socket cannot be retrieved."); } if (addr_s.ss_family == AF_INET) { portNum = ntohs(reinterpret_cast(&addr_s)->sin_port); } else { portNum = ntohs(reinterpret_cast(&addr_s)->sin6_port); } } static void missingOnConnect(int status) { TORCH_CHECK(false, "Socket accepted byt onConnect callback missing"); } static void on_new_connection(uv_stream_t* server, int status) { borrow(server)->onConnectCb(status); } }; class WriterPayload : public c10::intrusive_ptr_target { static c10::intrusive_ptr reclaim(uv_write_t* request) { /* This method returns a intrusive_ptr that does not increase the refcount. */ auto h = (WriterPayload*)uv_req_get_data((uv_req_t*)request); return c10::intrusive_ptr::reclaim(h); } void registeredInLoop() { /* This refcount increment must be matched by a reclaim call. Call this method after sucessfully scheduling this handle with a loop. */ at::raw::intrusive_ptr::incref(this); } static void write_done(uv_write_t* req, int status) { /* Since we're no longer actively used by the event loop, transfer ownership * to this frame. */ auto wp = WriterPayload::reclaim(req); auto handle = wp->handle; if (status) { C10D_WARNING( "Write to client failed. code:{} name:{} desc:{}.", status, uv_err_name(status), uv_strerror(status)); handle->close(); } } std::vector data; uv_write_t req = {}; uv_buf_t buf = {}; c10::intrusive_ptr handle; public: WriterPayload( std::vector&& in_data, c10::intrusive_ptr handle) : data(std::move(in_data)), handle(std::move(handle)) { uv_req_set_data((uv_req_t*)&req, this); } ~WriterPayload() override = default; void send() { buf = uv_buf_init((char*)data.data(), data.size()); int res = uv_write( &req, (uv_stream_t*)handle->unsafeGetHandle(), &buf, 1, write_done); if (res) { C10D_WARNING( "Write setup to client failed. code:{} name:{} desc:{}.", res, uv_err_name(res), uv_strerror(res)); handle->close(); } else { /* This object was successfully registered with the event loop, so keep it * alive until it's unregistered. */ registeredInLoop(); } } }; class StreamWriter { std::vector data; c10::intrusive_ptr handle; // must be stack allocated void* operator new(size_t); public: StreamWriter(c10::intrusive_ptr handle) : handle(std::move(handle)) {} void write1(uint8_t val) { data.push_back(val); } template void write_value(T val) { uint8_t* val_ptr = (uint8_t*)&val; data.insert(data.end(), val_ptr, val_ptr + sizeof(T)); } void write_vector(const std::vector& val) { write_value(val.size()); data.insert(data.end(), val.begin(), val.end()); } void write_string(const std::string& val) { write_value(val.size()); data.insert(data.end(), val.data(), val.data() + val.size()); } void send() { auto wd = c10::make_intrusive(std::move(data), handle); wd->send(); } }; class ChunkedStream { std::deque buffers; size_t buff_idx{0}; size_t buff_offset{0}; size_t capacity{0}; size_t buff_offset_commit{0}; size_t read_offset{0}; public: ChunkedStream() = default; size_t buf_count() { return buffers.size(); } void append(uv_buf_t buf) { if (buf.len == 0) { free(buf.base); } else { capacity += buf.len; buffers.push_back(buf); } } bool read_many(char* dest, size_t size) { if (available() < size) { return false; } size_t remaining = size; char* write_base = dest; while (remaining > 0) { auto to_read = std::min(buffers[buff_idx].len - buff_offset, remaining); ::memcpy(write_base, buffers[buff_idx].base + buff_offset, to_read); buff_offset += to_read; remaining -= to_read; write_base += to_read; if (buff_offset >= buffers[buff_idx].len) { buff_offset = 0; ++buff_idx; if (buff_idx >= buffers.size() && remaining > 0) { TORCH_CHECK( false, "Trying to read past end of buffer. ", "buffer_idx: ", buff_idx, ", available: ", buffers.size(), ", remaining: ", remaining); } } } read_offset += size; return true; } bool read1(uint8_t& byte) { while (true) { if (buff_idx >= buffers.size()) return false; if (buff_offset >= buffers[buff_idx].len) { buff_offset = 0; ++buff_idx; continue; } break; } byte = buffers[buff_idx].base[buff_offset]; ++buff_offset; ++read_offset; return true; } template bool read_value(T& value) { return read_many((char*)&value, sizeof(T)); } bool read_key(std::string& str) { uint64_t size = 0; if (!read_value(size)) return false; TORCH_CHECK( size <= MAX_STRING_LEN, "Invalid string size. ", "size: ", size, ", max: ", MAX_STRING_LEN); if (available() < size) return false; str.resize(size); return read_many((char*)str.data(), size); } bool read_payload(std::vector& data) { uint64_t size = 0; if (!read_value(size)) return false; auto size_in_bytes = size * sizeof(uint8_t); TORCH_CHECK( size_in_bytes <= MAX_PAYLOAD_LEN, "Invalid payload size. ", "size: ", size_in_bytes, ", max: ", MAX_PAYLOAD_LEN); if (available() < size_in_bytes) return false; data.resize(size); return read_many((char*)data.data(), size_in_bytes); } size_t available() { return capacity - read_offset; } void commit() { if (buff_idx >= buffers.size() || buff_offset >= buffers[buff_idx].len) { buff_offset = 0; if (buff_idx < buffers.size()) ++buff_idx; } for (size_t i = 0; i < buff_idx; ++i) { free(buffers[0].base); capacity -= buffers[0].len; buffers.pop_front(); } buff_idx = 0; read_offset = buff_offset_commit = buff_offset; } void reset() { buff_idx = 0; read_offset = buff_offset = buff_offset_commit; } }; class LibUVStoreDaemon : public BackgroundThread { public: explicit LibUVStoreDaemon(int port); ~LibUVStoreDaemon() override; uint16_t port() const override; void set(const std::string& key, const std::vector& value); const std::vector& compareAndSet( const std::string& key, const std::vector& expectedValue, const std::vector& newValue); const std::vector& get(const std::string& key); int64_t add(const std::string& key, int64_t addVal); bool checkKeys(const std::vector& keys); bool waitKeys( const std::vector& keys, const c10::intrusive_ptr& client); int64_t size(); int64_t deleteKey(const std::string& key); void append(const std::string& key, const std::vector& value); void registerClient(const c10::intrusive_ptr& client); void unregisterClient(const c10::intrusive_ptr& client); void clearClientWaitState(const c10::intrusive_ptr& client); bool isMiscellaneousClient(const c10::intrusive_ptr& client); uint16_t get_socket_port(uv_tcp_t* handle); void init(const TCPStoreOptions& opts); protected: void run() override; void stop() override; private: uv_loop_t loop{}; c10::intrusive_ptr tcpServer; uv_async_t exit_handle{}; std::unordered_map> tcpStore_; // From key -> the list of UvClient waiting on the key std::unordered_map>> waitingSockets_; // From socket -> number of keys awaited std::unordered_map, size_t> keysAwaited_; std::unordered_set> clients_; std::unordered_set> miscellaneousClients_; int port_; static LibUVStoreDaemon& from_uv(uv_handle_t* stream) { return *(LibUVStoreDaemon*)uv_handle_get_data(stream); } static void on_new_connection(uv_stream_t* server, int status) { from_uv((uv_handle_t*)server).onConnect(status); } static void on_exit_request(uv_async_t* handle) { from_uv((uv_handle_t*)handle).onExitRequest(); } void onConnect(int status); void onExitRequest(); void wakeupWaitingClients(const std::string& key); // bool tryListen(bool use_ipv6); static void print_active_handles(uv_handle_t* handle, void* arg); }; class UvClient : public UvTcpSocket { ChunkedStream stream; LibUVStoreDaemon* store; protected: void processBuf(const uv_buf_t* buf, size_t nread) override { auto tmp = *buf; tmp.len = nread; stream.append(tmp); while (true) { stream.reset(); uint8_t command = -1; if (!stream.read1(command)) break; if (store->isMiscellaneousClient(iptr())) { if ((QueryType)command != QueryType::VALIDATE) return; if (!parse_validate_command()) return; } else { switch ((QueryType)command) { case QueryType::PING: if (!parse_ping_command()) return; break; case QueryType::SET: if (!parse_set_command()) return; break; case QueryType::COMPARE_SET: if (!parse_compare_set_command()) return; break; case QueryType::GET: if (!parse_get_command()) return; break; case QueryType::ADD: if (!parse_add_command()) return; break; case QueryType::CHECK: if (!parse_check_command()) return; break; case QueryType::WAIT: if (!parse_wait_command()) return; break; case QueryType::GETNUMKEYS: if (!parse_getnumkeys_command()) return; break; case QueryType::DELETE_KEY: if (!parse_delete_key_command()) return; break; case QueryType::APPEND: if (!parse_append_command()) return; break; case QueryType::MULTI_GET: if (!parse_multi_get_command()) return; break; case QueryType::MULTI_SET: if (!parse_multi_set_command()) return; break; case QueryType::CANCEL_WAIT: if (!parse_cancel_wait_command()) return; break; default: C10D_DEBUG( "Client sent invalid command. client:{} command:{}", (void*)this, (int)command); close(); return; } } stream.commit(); } } bool parse_validate_command() { uint32_t validateNumber = 0; if (!stream.read_value(validateNumber)) return false; if (validateNumber != c10d::detail::validationMagicNumber) return false; return true; } bool parse_ping_command() { uint32_t nonce; if (!stream.read_value(nonce)) { return false; } StreamWriter sw(iptr()); sw.write_value(nonce); sw.send(); return true; } bool parse_set_command() { std::string key; if (!stream.read_key(key)) return false; std::vector newData; if (!stream.read_payload(newData)) return false; store->set(key, newData); return true; } bool parse_compare_set_command() { std::string key; if (!stream.read_key(key)) return false; std::vector currentValue; if (!stream.read_payload(currentValue)) return false; std::vector newValue; if (!stream.read_payload(newValue)) return false; auto res = store->compareAndSet(key, currentValue, newValue); StreamWriter sw(iptr()); sw.write_vector(res); sw.send(); return true; } bool parse_get_command() { std::string key; if (!stream.read_key(key)) return false; const auto& data = store->get(key); StreamWriter sw(iptr()); sw.write_vector(data); sw.send(); return true; } bool parse_add_command() { std::string key; if (!stream.read_key(key)) return false; int64_t addVal = 0; if (!stream.read_value(addVal)) return false; addVal = store->add(key, addVal); StreamWriter sw(iptr()); sw.write_value(addVal); sw.send(); return true; } bool parse_check_command() { uint64_t key_count = 0; if (!stream.read_value(key_count)) return false; TORCH_CHECK( key_count <= MAX_KEY_COUNT, "Too many keys being waited. ", "keys: ", key_count, ", max: ", MAX_KEY_COUNT); std::vector keys(key_count); for (uint64_t i = 0; i < key_count; ++i) { if (!stream.read_key(keys[i])) return false; } // Now we have received all the keys StreamWriter sw(iptr()); if (store->checkKeys(keys)) { sw.write_value(CheckResponseType::READY); } else { sw.write_value(CheckResponseType::NOT_READY); } sw.send(); return true; } bool parse_wait_command() { uint64_t key_count = 0; if (!stream.read_value(key_count)) { return false; } TORCH_CHECK( key_count <= MAX_KEY_COUNT, "Too many keys being waited. ", "keys: ", key_count, ", max: ", MAX_KEY_COUNT); std::vector keys(key_count); for (uint64_t i = 0; i < key_count; ++i) { if (!stream.read_key(keys[i])) return false; } if (store->waitKeys(keys, iptr())) { StreamWriter sw(iptr()); sw.write1((uint8_t)WaitResponseType::STOP_WAITING); sw.send(); } return true; } bool parse_getnumkeys_command() { StreamWriter sw(iptr()); sw.write_value(store->size()); sw.send(); return true; } bool parse_delete_key_command() { std::string key; if (!stream.read_key(key)) return false; auto numDeleted = store->deleteKey(key); StreamWriter sw(iptr()); sw.write_value(numDeleted); sw.send(); return true; } bool parse_append_command() { std::string key; if (!stream.read_key(key)) { return false; } std::vector data; if (!stream.read_payload(data)) { return false; } store->append(key, data); return true; } bool parse_multi_get_command() { uint64_t key_count = 0; if (!stream.read_value(key_count)) { return false; } TORCH_CHECK( key_count <= MAX_KEY_COUNT, "Too many keys with multi_get. ", "keys: ", key_count, ", max: ", MAX_KEY_COUNT); StreamWriter sw(iptr()); for (const auto _ : c10::irange(key_count)) { (void)_; // Suppress unused variable warning std::string key; if (!stream.read_key(key)) { return false; } sw.write_vector(store->get(key)); } sw.send(); return true; } bool parse_multi_set_command() { uint64_t key_count = 0; if (!stream.read_value(key_count)) { return false; } TORCH_CHECK( key_count <= MAX_KEY_COUNT, "Too many keys with multi_get. ", "keys: ", key_count, ", max: ", MAX_KEY_COUNT); for (const auto _ : c10::irange(key_count)) { (void)_; // Suppress unused variable warning std::string key; if (!stream.read_key(key)) { return false; } std::vector newData; if (!stream.read_payload(newData)) return false; store->set(key, newData); } return true; } bool parse_cancel_wait_command() { store->clearClientWaitState(iptr()); StreamWriter sw(iptr()); sw.write1((uint8_t)WaitResponseType::WAIT_CANCELED); sw.send(); return true; } public: explicit UvClient(uv_loop_t* loop, LibUVStoreDaemon* store) : UvTcpSocket(loop), store(store) {} static c10::intrusive_ptr make( uv_loop_t* loop, LibUVStoreDaemon* store) { auto res = c10::make_intrusive(loop, store); res->handleReady(); return res; } c10::intrusive_ptr iptr() { return c10::intrusive_ptr::reclaim_copy(this); } protected: void onClose() override { store->unregisterClient(iptr()); } }; void LibUVStoreDaemon::onConnect(int status) { auto client = UvClient::make(&loop, this); registerClient(client); try { tcpServer->accept(client); client->startRead(); } catch (std::exception& e) { C10D_WARNING("Failed to accept client due to {}", e.what()); client->close(); } } void LibUVStoreDaemon::onExitRequest() { C10D_DEBUG("Store exit requested\n"); uv_close((uv_handle_t*)&exit_handle, nullptr); uv_stop(&loop); } void LibUVStoreDaemon::init(const TCPStoreOptions& opts) { if (opts.masterListenFd.has_value()) { tcpServer = UvTcpServer::makeWithSocket(&loop, *opts.masterListenFd); } else { try { tcpServer = UvTcpServer::makeWithPort(&loop, opts.port, /*useIpv6=*/true); } catch (std::exception& ex) { C10D_INFO( "Failed to bind to ipv6 address, trying ipv4. Error: {}", ex.what()); tcpServer = UvTcpServer::makeWithPort(&loop, opts.port, /*useIpv6=*/false); } } tcpServer->setOnConnectCallback( [this](auto status) { this->onConnect(status); }); port_ = tcpServer->port(); TORCH_CHECK( port_ == opts.port || opts.port == 0, // zero means use any port "listen fd ", *opts.masterListenFd, " is bound to port ", port_, ", expected to be bound to port ", opts.port); } LibUVStoreDaemon::LibUVStoreDaemon(int port) : port_(port) { TORCH_CHECK(uv_loop_init(&loop) == 0, "Failed to init uv loop"); TORCH_CHECK( uv_async_init(&loop, &exit_handle, LibUVStoreDaemon::on_exit_request) == 0, "Failed to init uv async event"); uv_handle_set_data((uv_handle_t*)&exit_handle, this); } LibUVStoreDaemon::~LibUVStoreDaemon() { if (!is_running()) { uv_close((uv_handle_t*)&exit_handle, nullptr); uv_run(&loop, UV_RUN_NOWAIT); TORCH_CHECK(uv_loop_close(&loop) == 0, "loop cleanup didn't work"); } else { // the daemon thread cleanup libuv dispose(); } } uint16_t LibUVStoreDaemon::port() const { return port_; } void LibUVStoreDaemon::print_active_handles(uv_handle_t* handle, void* arg) { C10D_DEBUG( "UV live handle type {} active:{} is-closing:{}", (int)handle->type, uv_is_active(handle), uv_is_closing(handle)); } void LibUVStoreDaemon::run() { c10::setThreadName("pt_tcpstore_uv"); C10D_DEBUG("Uv main loop running"); int res = uv_run(&loop, UV_RUN_DEFAULT); if (res) { C10D_DEBUG("UV main loop done: res:{}", res); } bool debug_enabled = c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Debug); if (debug_enabled) { C10D_DEBUG("Walking live handles prior to closing clients"); uv_walk(&loop, LibUVStoreDaemon::print_active_handles, nullptr); } for (const auto& client : clients_) { client->close(); } tcpServer->close(); if (debug_enabled) { C10D_DEBUG("Walking live handles after closing clients"); uv_walk(&loop, LibUVStoreDaemon::print_active_handles, nullptr); } while (true) { res = uv_loop_close(&loop); if (res == 0) { break; } C10D_INFO( "uv_loop_close failed with:{} errn:{} desc:{}", res, uv_err_name(res), uv_strerror(res)); res = uv_run(&loop, UV_RUN_NOWAIT); if (res != 0) { std::this_thread::sleep_for(std::chrono::milliseconds(500)); } } C10D_INFO("uv_loop cleanup finished."); } void LibUVStoreDaemon::stop() { int res = uv_async_send(&exit_handle); if (res) { C10D_WARNING( "uv_async_send failed with:{} errn:{} desc:{}\n", res, uv_err_name(res), uv_strerror(res)); } } bool LibUVStoreDaemon::isMiscellaneousClient( const c10::intrusive_ptr& client) { if (miscellaneousClients_.find(client) != miscellaneousClients_.end()) { miscellaneousClients_.erase(client); return true; } return false; } void LibUVStoreDaemon::registerClient( const c10::intrusive_ptr& client) { clients_.insert(client); miscellaneousClients_.insert(client); } void LibUVStoreDaemon::unregisterClient( const c10::intrusive_ptr& client) { clients_.erase(client); if (miscellaneousClients_.find(client) != miscellaneousClients_.end()) { miscellaneousClients_.erase(client); } clearClientWaitState(client); } void LibUVStoreDaemon::clearClientWaitState( const c10::intrusive_ptr& client) { if (keysAwaited_.find(client) == keysAwaited_.end()) { return; } keysAwaited_.erase(client); for (auto it = waitingSockets_.begin(); it != waitingSockets_.end();) { for (auto vecIt = it->second.begin(); vecIt != it->second.end();) { if (*vecIt == client) { vecIt = it->second.erase(vecIt); } else { ++vecIt; } } if (it->second.empty()) { it = waitingSockets_.erase(it); } else { ++it; } } } void LibUVStoreDaemon::set( const std::string& key, const std::vector& value) { tcpStore_[key] = value; // On "set", wake up all clients that have been waiting wakeupWaitingClients(key); } const std::vector& LibUVStoreDaemon::compareAndSet( const std::string& key, const std::vector& expectedValue, const std::vector& newValue) { auto pos = tcpStore_.find(key); if (pos == tcpStore_.end()) { if (expectedValue.empty()) { tcpStore_[key] = newValue; wakeupWaitingClients(key); return newValue; } else { // TODO: This code path is not ideal as we are "lying" to the caller in // case the key does not exist. We should come up with a working solution. // It might make more sense to return "" wakeupWaitingClients(key); return expectedValue; } } else { if (pos->second == expectedValue) { pos->second = newValue; } wakeupWaitingClients(key); return pos->second; } } const std::vector& LibUVStoreDaemon::get(const std::string& key) { static std::vector missing_key; return tcpStore_.count(key) ? tcpStore_.at(key) : missing_key; } int64_t LibUVStoreDaemon::add(const std::string& key, int64_t addVal) { std::vector oldData; auto it = tcpStore_.find(key); if (it != tcpStore_.end()) { oldData = it->second; auto buf = reinterpret_cast(it->second.data()); auto len = it->second.size(); addVal += std::stoll(std::string(buf, len)); } auto addValStr = std::to_string(addVal); std::vector newData = std::vector(addValStr.begin(), addValStr.end()); tcpStore_[key] = newData; // On "add", wake up all clients that have been waiting wakeupWaitingClients(key); return addVal; } bool LibUVStoreDaemon::checkKeys(const std::vector& keys) { return std::all_of(keys.begin(), keys.end(), [&](const std::string& s) { return tcpStore_.count(s) > 0; }); } bool LibUVStoreDaemon::waitKeys( const std::vector& keys, const c10::intrusive_ptr& client) { if (checkKeys(keys)) { return true; } int numKeysToAwait = 0; for (auto& key : keys) { // Only count keys that have not already been set if (tcpStore_.find(key) == tcpStore_.end()) { waitingSockets_[key].push_back(client); numKeysToAwait++; } } keysAwaited_[client] = numKeysToAwait; return false; } int64_t LibUVStoreDaemon::size() { return tcpStore_.size(); } int64_t LibUVStoreDaemon::deleteKey(const std::string& key) { return tcpStore_.erase(key); } void LibUVStoreDaemon::append( const std::string& key, const std::vector& value) { std::vector oldData; auto it = tcpStore_.find(key); if (it != tcpStore_.end()) { it->second.insert(it->second.end(), value.begin(), value.end()); } else { tcpStore_[key] = value; } // we should not have clients waiting if we're appending, so it's all fine wakeupWaitingClients(key); } void LibUVStoreDaemon::wakeupWaitingClients(const std::string& key) { auto socketsToWait = waitingSockets_.find(key); if (socketsToWait != waitingSockets_.end()) { for (const auto& client : socketsToWait->second) { if (--keysAwaited_[client] == 0) { StreamWriter sw(client->iptr()); sw.write1((uint8_t)WaitResponseType::STOP_WAITING); sw.send(); } } waitingSockets_.erase(socketsToWait); } } #endif std::unique_ptr create_libuv_tcpstore_backend( const TCPStoreOptions& opts) { #ifdef TORCH_USE_LIBUV auto res = std::make_unique(opts.port); res->init(opts); return res; #else TORCH_CHECK(false, "LibUV TCPStore implementation missing"); #endif } bool is_libuv_tcpstore_backend_available() { #ifdef TORCH_USE_LIBUV return true; #else return false; #endif } } // namespace c10d::detail