#include #include namespace torch::distributed::rpc { Message::Message() = default; Message::Message( std::vector&& payload, std::vector&& tensors, MessageType type) : payload_(std::move(payload)), tensors_(std::move(tensors)), type_(type) {} Message::Message( std::vector&& payload, std::vector&& tensors, MessageType type, int64_t id) : payload_(std::move(payload)), tensors_(std::move(tensors)), type_(type), id_(id) {} std::vector&& Message::movePayload() && { return std::move(payload_); } std::vector& Message::payload() { return payload_; } const std::vector& Message::payload() const { return payload_; } std::vector&& Message::moveTensors() && { return std::move(tensors_); } std::vector& Message::tensors() { return tensors_; } const std::vector& Message::tensors() const { return tensors_; } MessageType Message::type() const { return type_; } bool Message::isRequest() const { return MessageTypeFlags::REQUEST_TYPE & type_; } bool Message::isResponse() const { return MessageTypeFlags::RESPONSE_TYPE & type_; } int64_t Message::id() const { return id_; } void Message::setId(int64_t id) { id_ = id; } std::vector> Message::getStorages() const { // Sparse tensors do not have storage. Instead, a sparse tensor // contains two tensors indices and values, and both contain storage. std::vector> storages; storages.reserve(2 * tensors_.size()); for (const auto& tensor : tensors_) { if (tensor.is_sparse()) { storages.emplace_back(tensor._indices().storage().getWeakStorageImpl()); storages.emplace_back(tensor._values().storage().getWeakStorageImpl()); } else { storages.emplace_back(tensor.storage().getWeakStorageImpl()); } } return storages; } c10::intrusive_ptr createExceptionResponse( const std::exception& e, int64_t id) { return createExceptionResponse(e.what(), id); } c10::intrusive_ptr createExceptionResponse( const std::string& exceptionStr, int64_t id) { std::vector payload(exceptionStr.begin(), exceptionStr.end()); return c10::make_intrusive( std::move(payload), std::vector(), MessageType::EXCEPTION, id); } namespace { // NB: need to call torch::class_ to register Message in the map returned by // c10::getCustomClassTypeMap(). Otherwise, Message cannot be wrapped within // an IValue. // NB: add this line here instead of in rpc/init.cpp because 1) we have C++ // only tests that won't run rpc/init.cpp; 2) Message is not meant to be // visible from Python. static const auto message = torch::class_("rpc", "_Message"); } // namespace } // namespace torch::distributed::rpc