#include <torch/csrc/distributed/rpc/utils.h>

#include <fmt/format.h>

#include <c10/util/irange.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rref_backward_resp.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h>
#include <torch/csrc/distributed/rpc/python_call.h>
#include <torch/csrc/distributed/rpc/python_remote_call.h>
#include <torch/csrc/distributed/rpc/python_resp.h>
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/script_call.h>
#include <torch/csrc/distributed/rpc/script_remote_call.h>
#include <torch/csrc/distributed/rpc/script_resp.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/jit/serialization/unpickler.h>
#include <torch/csrc/utils/byte_order.h>

using namespace torch::autograd::profiler;

namespace torch::distributed::rpc {
namespace {
void processRemoteProfiledEvents(
    autograd::RpcWithProfilingResp& rpcWithProfilingResp) {
  // Check if the profiler is enabled
  auto enabled = profilerEnabled();
  TORCH_CHECK(
      enabled,
      "Profiler was expected to be enabled. This can happen in callback "
      " continuations that run in different threads, and the TLS of the "
      " profiler was not propagated.");
  std::vector<LegacyEvent> events = rpcWithProfilingResp.getProfiledEvents();
  const auto& profilingId = rpcWithProfilingResp.getProfilingId();
  auto& remoteProfilerManager = RemoteProfilerManager::getInstance();
  auto key = remoteProfilerManager.retrieveRPCProfilingKey(profilingId);
  remoteProfilerManager.eraseKey(profilingId);
  auto keyPrefixStr = key + rpc::REMOTE_PROFILING_KEY_PREFIX;
  std::for_each(
      events.begin(), events.end(), [&keyPrefixStr](LegacyEvent& event) {
        std::string name = keyPrefixStr + std::string(event.name());
        event.setName(at::StringView(name));
      });
  // Add event list to the thread local profiler.
  addEventList(std::move(events));
}
} // namespace

const std::string kRPCErrorPrefix = std::string("RPCErr");

RPCErrorType getRPCErrorType(const JitFuture& jitFuture) {
  TORCH_INTERNAL_ASSERT(
      jitFuture.hasError(),
      "JitFuture of Message passed to getRPCErrorType does not have an error.");

  // Attempt to parse for the error string given by makeRPCError; otherwise
  // return unknown error.
  // Note that this function expects errors formatted with makeRPCError().
  auto err = jitFuture.tryRetrieveErrorMessage();
  size_t pos = err.find(kRPCErrorPrefix);
  if (pos != std::string::npos) {
    // Parse the RPCErrorType.
    auto errStartIdx =
        pos + torch::distributed::rpc::kRPCErrorPrefix.size() + 1;
    auto errEndIdx = err.find(':', errStartIdx);
    if (errEndIdx == std::string::npos) {
      // Indicates error was not formatted correctly.
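      // (Errors produced by makeRPCError() below have the form
      // "RPCErr:<errorType>:<errorMsg>"; without the second ':' the error
      // type cannot be recovered, so fall back to UNKNOWN_ERROR.)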
      return RPCErrorType::UNKNOWN_ERROR;
    }
    auto errStr = err.substr(errStartIdx, errEndIdx - errStartIdx);
    auto errType = static_cast<RPCErrorType>(std::stoi(errStr));
    return errType;
  } else {
    return RPCErrorType::UNKNOWN_ERROR;
  }
}

std::string makeRPCError(
    const std::string& rpcErrorStr,
    RPCErrorType errorType) {
  return fmt::format(
      "{}:{}:{}",
      torch::distributed::rpc::kRPCErrorPrefix,
      static_cast<int>(errorType),
      rpcErrorStr);
}

std::unique_ptr<RpcCommandBase> deserializeRequest(const Message& request) {
  switch (request.type()) {
    case MessageType::SCRIPT_CALL: {
      return ScriptCall::fromMessage(request);
    }
    case MessageType::PYTHON_CALL: {
      return PythonCall::fromMessage(request);
    }
    case MessageType::SCRIPT_REMOTE_CALL: {
      return ScriptRemoteCall::fromMessage(request);
    }
    case MessageType::PYTHON_REMOTE_CALL: {
      return PythonRemoteCall::fromMessage(request);
    }
    case MessageType::SCRIPT_RREF_FETCH_CALL: {
      return ScriptRRefFetchCall::fromMessage(request);
    }
    case MessageType::PYTHON_RREF_FETCH_CALL: {
      return PythonRRefFetchCall::fromMessage(request);
    }
    case MessageType::RREF_USER_DELETE: {
      return RRefUserDelete::fromMessage(request);
    }
    case MessageType::RREF_CHILD_ACCEPT: {
      return RRefChildAccept::fromMessage(request);
    }
    case MessageType::RREF_FORK_REQUEST: {
      return RRefForkRequest::fromMessage(request);
    }
    case MessageType::FORWARD_AUTOGRAD_REQ: {
      return autograd::RpcWithAutograd::fromMessage(request);
    }
    case MessageType::BACKWARD_AUTOGRAD_REQ: {
      return autograd::PropagateGradientsReq::fromMessage(request);
    }
    case MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ: {
      return autograd::CleanupAutogradContextReq::fromMessage(request);
    }
    case MessageType::RUN_WITH_PROFILING_REQ: {
      return autograd::RpcWithProfilingReq::fromMessage(request);
    }
    case MessageType::RREF_BACKWARD_REQ: {
      return autograd::RRefBackwardReq::fromMessage(request);
    }
    default: {
      TORCH_INTERNAL_ASSERT(
          false, "Request type ", request.type(), " not supported.");
    }
  }
}

std::unique_ptr<RpcCommandBase> deserializeResponse(
    const Message& response,
    MessageType& wrappedMsgType) {
  switch (response.type()) {
    case MessageType::SCRIPT_RET: {
      return ScriptResp::fromMessage(response);
    }
    case MessageType::PYTHON_RET: {
      return PythonResp::fromMessage(response);
    }
    case MessageType::REMOTE_RET: {
      return RemoteRet::fromMessage(response);
    }
    case MessageType::SCRIPT_RREF_FETCH_RET: {
      return ScriptRRefFetchRet::fromMessage(response);
    }
    case MessageType::PYTHON_RREF_FETCH_RET: {
      return PythonRRefFetchRet::fromMessage(response);
    }
    case MessageType::RREF_ACK: {
      return RRefAck::fromMessage(response);
    }
    case MessageType::FORWARD_AUTOGRAD_RESP: {
      std::unique_ptr<RpcCommandBase> rpcPtr =
          autograd::RpcWithAutograd::fromMessage(response);
      RpcCommandBase& rpc = *rpcPtr;
      auto& rpcWithAutograd = static_cast<autograd::RpcWithAutograd&>(rpc);

      // Need to reverse the device map for the backward pass of distributed
      // autograd.
      DeviceMap reverseDeviceMap;
      for (const auto& mapEntry : rpcWithAutograd.deviceMap()) {
        reverseDeviceMap.insert({mapEntry.second, mapEntry.first});
      }

      // Attach 'recv' autograd function.
      addRecvRpcBackward(
          rpcWithAutograd.autogradMetadata(),
          rpcWithAutograd.tensors(),
          rpcWithAutograd.fromWorkerId(),
          reverseDeviceMap);

      wrappedMsgType = rpcWithAutograd.wrappedMessageType();

      return std::move(rpcWithAutograd).moveWrappedRpc();
    }
    case MessageType::BACKWARD_AUTOGRAD_RESP: {
      return autograd::PropagateGradientsResp::fromMessage(response);
    }
    case MessageType::CLEANUP_AUTOGRAD_CONTEXT_RESP: {
      return autograd::CleanupAutogradContextResp::fromMessage(response);
    }
    case MessageType::RUN_WITH_PROFILING_RESP: {
      std::unique_ptr<RpcCommandBase> rpcPtr =
          autograd::RpcWithProfilingResp::fromMessage(response);
      RpcCommandBase& rpc = *rpcPtr;
      auto& rpcWithProfilingResp =
          static_cast<autograd::RpcWithProfilingResp&>(rpc);
      // Process remotely profiled events.
      processRemoteProfiledEvents(rpcWithProfilingResp);

      wrappedMsgType = rpcWithProfilingResp.wrappedMessageType();
      auto wrappedRPC = std::move(rpcWithProfilingResp).moveWrappedRpc();
      return wrappedRPC;
    }
    case MessageType::RREF_BACKWARD_RESP: {
      return autograd::RRefBackwardResp::fromMessage(response);
    }
    default: {
      TORCH_INTERNAL_ASSERT(
          false, "Response type ", response.type(), " not supported.");
    }
  }
}

IValue deserializeResptoIValueInternal(
    RpcCommandBase& rpc,
    MessageType messageType) {
  switch (messageType) {
    case MessageType::SCRIPT_RET: {
      auto& ret = static_cast<ScriptResp&>(rpc);
      return ret.value();
    }
    default: {
      TORCH_INTERNAL_ASSERT(
          false,
          "Response type ",
          messageType,
          " is not supported to be deserialized to IValue.");
    }
  }
}

IValue deserializeRespToIValue(const Message& message) {
  MessageType msgType = message.type();
  auto response = deserializeResponse(message, msgType);
  return deserializeResptoIValueInternal(*response, msgType);
}

namespace {

// Helper for wireDeserialize() below.
//
// The format we use below looks like:
//    section_name_1 size_1\n
//    section_name_2 size_2\n
//    ..
//    \n
//    [sections in order]
//
// Sections themselves include:
//    - "payload" - the payload bits
//    - "meta"    - metadata for the unpickler
//    - "0" ...   - tensor sections for the unpickler
//
// Note that per the header comments, the format is subject to change,
// and is best used for rpcs, rather than persistent disk storage.
std::unordered_map<std::string, std::pair<const char*, size_t>>
parseWireSections(const void* data, size_t data_size) {
  const char* ptr = static_cast<const char*>(data);
  const char* endp = ptr + data_size;

  std::vector<std::pair<std::string, size_t>> headerEnts;
  bool ok = false;
  while (ptr != endp) {
    if (*ptr == '\n') {
      ok = true; // The only "correct" exit point.
      ++ptr;
      break;
    }
    // Parse name
    const char* namePtr = ptr;
    while (ptr != endp && *ptr != ' ') {
      ptr++;
    }
    if (ptr == endp) {
      break;
    }
    std::string name(namePtr, ptr - namePtr);
    if (++ptr == endp) {
      break; // past the ' '
    }
    // Parse size
    const char* sizePtr = ptr;
    while (ptr != endp && *ptr != '\n') {
      ptr++;
    }
    if (ptr == endp) {
      break;
    }
    size_t sz = std::stoll(std::string(sizePtr, ptr - sizePtr));
    headerEnts.emplace_back(name, sz);
    ++ptr; // past the '\n'
  }
  if (!ok) {
    TORCH_CHECK(false, "failed parse");
  }

  std::unordered_map<std::string, std::pair<const char*, size_t>> out;
  for (const auto& headerEnt : headerEnts) {
    out[headerEnt.first] = {ptr, headerEnt.second};
    ptr += headerEnt.second;
  }
  if (ptr != endp) {
    TORCH_CHECK(false, "failed bounds");
  }
  return out;
}

static const char* kMeta = "meta";
static const char* kPayload = "payload";
} // namespace

c10::List<at::Tensor> cloneSparseTensors(
    const std::vector<at::Tensor>& tensors) {
  // Sanity-check: If the majority of bits don't need to go over the wire,
  // force a clone(). Some Tensors are effectively small views, only using
  // ~1% of the underlying Storage.
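  // For example, a narrow() view into a large tensor may reference only a few
  // kilobytes of a multi-megabyte Storage; cloning such a view first means the
  // pickler serializes only the bytes that are actually used.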
  auto worthRecopying = [](const at::Tensor& t) -> bool {
    if (!t.has_storage()) {
      return false; // avoid throwing below.
    }
    auto storageSize = t.storage().nbytes();
    auto usefulSize = t.element_size() * t.numel();
    constexpr size_t kMinMultiple = 2;
    constexpr size_t kMinRecopyBytes = 8 * 1024;
    return storageSize >= kMinRecopyBytes &&
        storageSize >= usefulSize * kMinMultiple;
  };
  c10::List<at::Tensor> pTensors;
  pTensors.reserve(tensors.size());
  for (const auto& t : tensors) {
    pTensors.push_back(worthRecopying(t) ? t.clone() : t);
  }
  return pTensors;
}

std::string wireSerialize(
    const std::vector<char>& payload,
    const std::vector<at::Tensor>& tensors) {
  for (const auto& tensor : tensors) {
    TORCH_CHECK(
        tensor.device().is_cpu(),
        "ProcessGroup RPC backend only supports",
        " CPU tensors, please move your tensors to CPU before sending ",
        "them over RPC. Found tensor on device: ",
        tensor.device());
  }

  struct Ent {
    std::string name;
    const char* data;
    size_t size;
  };
  std::vector<Ent> entries;
  std::string metaEntry;
  std::vector<at::Tensor> tensorData;

  if (!payload.empty()) {
    entries.push_back({kPayload, payload.data(), payload.size()});
  }

  if (!tensors.empty()) {
    torch::jit::Pickler pickler([&](const void* buf, size_t sz) -> size_t {
      metaEntry.append(static_cast<const char*>(buf), sz);
      return sz;
    });
    pickler.protocol();
    pickler.pushIValue(cloneSparseTensors(tensors));
    pickler.stop();
    tensorData = pickler.tensorData();
    entries.push_back({kMeta, metaEntry.data(), metaEntry.size()});
    for (const auto i : c10::irange(tensorData.size())) {
      // Construct WriteableTensorData for each tensor in the pickler's
      // tensorData. Since tensorData is in function scope, and
      // getWriteableTensorData just records the tensors, the data() pointers
      // stay valid for CPU tensors. Note that RPC serde doesn't support CUDA
      // tensors yet; if we were to support CUDA tensors, we would need to be
      // careful, since getWriteableTensorData converts a CUDA tensor to CPU
      // and its data() might get destructed as we go out of the scope of this
      // loop.
      auto writeableTensorData = jit::getWriteableTensorData(tensorData[i]);
      entries.push_back(
          {std::to_string(i),
           writeableTensorData.data(),
           writeableTensorData.sizeInBytes()});
    }
  }

  std::string header;
  size_t tot = 0;
  for (const auto& e : entries) {
    tot += e.size;
    header.append(e.name)
        .append(" ")
        .append(std::to_string(e.size))
        .append("\n");
  }
  header.push_back('\n');

  std::string out;
  out.reserve(header.size() + tot);
  out.append(header);
  for (const auto& e : entries) {
    out.append(e.data, e.size);
  }
  return out;
}

std::pair<std::vector<char>, std::vector<at::Tensor>> wireDeserialize(
    const void* data,
    size_t data_size) {
  auto sections = parseWireSections(data, data_size);

  std::vector<char> payload;
  auto payloadIt = sections.find(kPayload);
  if (payloadIt != sections.end() && payloadIt->second.second != 0) {
    payload.assign(
        payloadIt->second.first,
        payloadIt->second.first + payloadIt->second.second);
  }

  std::vector<at::Tensor> tensors;
  auto metaIt = sections.find(kMeta);
  if (metaIt != sections.end()) {
    const auto& metaData = metaIt->second;
    size_t metaDataPos = 0;
    auto metaDataReadFunc = [&](char* buf, size_t n) -> size_t {
      if (metaDataPos >= metaData.second || n == 0) {
        return 0;
      }
      size_t toCopy = std::min(metaDataPos + n, metaData.second) - metaDataPos;
      memcpy(buf, metaData.first + metaDataPos, toCopy);
      metaDataPos += toCopy;
      return toCopy;
    };
    auto sectionReadFunc = [&](const std::string& ename) -> at::DataPtr {
      auto it = sections.find(ename);
      if (it == sections.end()) {
        TORCH_CHECK(false, "Couldn't find entity " + ename);
      }
      const auto& idat = it->second;
      auto dptr = at::getCPUAllocator()->allocate(idat.second);
      if (idat.second != 0) {
        memcpy(dptr.get(), idat.first, idat.second);
      }
      return dptr;
    };

    // No need to pass typeResolver here, as it always processes strings and
    // tensors only.
    torch::jit::Unpickler unpickler(
        metaDataReadFunc, nullptr, nullptr, sectionReadFunc, {});
    auto ival = unpickler.parse_ivalue();
    for (auto&& t : ival.toTensorList()) {
      tensors.emplace_back(std::move(t));
    }
  }
  return {std::move(payload), std::move(tensors)};
}

void writeWrappedPayload(
    std::vector<char>& originalPayload,
    std::vector<char>& additionalPayload) {
  originalPayload.insert(
      originalPayload.end(),
      additionalPayload.begin(),
      additionalPayload.end());

  // Add size of the additional payload
  int64_t indexToWrite = originalPayload.size();
  originalPayload.resize(originalPayload.size() + sizeof(int64_t));
  const int64_t additionalPayloadSize = additionalPayload.size();
  torch::utils::THP_encodeInt64Buffer(
      reinterpret_cast<uint8_t*>(originalPayload.data()) + indexToWrite,
      &additionalPayloadSize,
      torch::utils::THPByteOrder::THP_BIG_ENDIAN,
      1);
}

std::vector<at::IValue> readWrappedPayload(
    std::vector<char>& payload,
    const rpc::Message& message) {
  // Read the additional payload and remove it from the payload.
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t additionalPayloadSize;
  TORCH_INTERNAL_ASSERT(payload.size() >= sizeof(int64_t));
  size_t indexToRead = payload.size() - sizeof(int64_t);
  torch::utils::THP_decodeInt64Buffer(
      &additionalPayloadSize,
      reinterpret_cast<uint8_t*>(payload.data()) + indexToRead,
      torch::utils::THPByteOrder::THP_BIG_ENDIAN,
      1);
  payload.resize(indexToRead);

  TORCH_INTERNAL_ASSERT(
      additionalPayloadSize > 0 &&
          static_cast<int64_t>(payload.size()) > additionalPayloadSize,
      "Wrong payload sizes: payload.size() is ",
      payload.size(),
      " but additional payload size is ",
      additionalPayloadSize);
  auto wrappedPayloadBegin =
      static_cast<const char*>(message.payload().data()) + payload.size() -
      additionalPayloadSize;
  std::vector<torch::Tensor> tensorTable;
  IValue tuple = jit::unpickle(
      wrappedPayloadBegin,
      additionalPayloadSize,
      *rpc::RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
      tensorTable);
  std::vector<at::IValue> tupleElements = tuple.toTupleRef().elements().vec();
  payload.resize(payload.size() - additionalPayloadSize);
  return tupleElements;
}

void populateRemoteProfiledEvents(
    std::vector<LegacyEvent>& profiledEvents,
    const ProfilerConfig& profilingConfig,
    const std::vector<std::vector<LegacyEvent>>& eventLists) {
  // Gather all events into a vector
  for (auto& l : eventLists) {
    for (auto& e : l) {
      profiledEvents.push_back(e);
    }
  }

  // Find the __start_profile event.
  bool cudaProfilingEnabled = profilingConfig.state == ProfilerState::CUDA;
  const LegacyEvent* profilerStart = nullptr;

  for (auto& e : profiledEvents) {
    if (std::string(e.name()) == "__start_profile") {
      profilerStart = &e;
      break;
    }
  }
  // We should always find __start_profile.
  TORCH_CHECK(
      profilerStart != nullptr, "Expected to find __start_profile event.");

  if (cudaProfilingEnabled) {
    // Deserialized events don't have the corresponding CUDA events, making it
    // impossible to use cudaEventElapsedTime on the receiving end. To avoid
    // this, find all push/pop pairs of CUDA events and set the corresponding
    // CUDA time to zero for the push event and to the elapsed time for the
    // pop event, to be used later for the elapsed CUDA time computation.
    std::unordered_map<at::RecordFunctionHandle, const LegacyEvent*>
        startEvents;
    for (auto& e : profiledEvents) {
      if (e.hasCuda()) {
        if (e.kind() == EventKind::PushRange) {
          startEvents[e.handle()] = &e;
        }
      }
    }
    for (auto& e : profiledEvents) {
      if (e.hasCuda()) {
        if (e.kind() == EventKind::PopRange) {
          auto it = startEvents.find(e.handle());
          if (it != startEvents.end()) {
            e.setCudaUs(it->second->cudaElapsedUs(e));
          } else {
            TORCH_WARN("Found a pop event without a corresponding push event");
            e.setCudaUs(0);
          }
        } else {
          e.setCudaUs(0);
        }
      }
    }
  }
}

} // namespace torch::distributed::rpc
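// A minimal usage sketch of the wire format above (illustrative only; the
// payload bytes and tensor are arbitrary, and the header contents shown are
// approximate):
//
//   std::vector<char> payload = {'h', 'i'};
//   std::vector<at::Tensor> tensors = {torch::ones({2, 2})};
//   std::string wire =
//       torch::distributed::rpc::wireSerialize(payload, tensors);
//   // `wire` begins with a text header such as "payload 2\nmeta <N>\n0 16\n\n",
//   // followed by the raw bytes of each section in order.
//   auto [outPayload, outTensors] =
//       torch::distributed::rpc::wireDeserialize(wire.data(), wire.size());
//   // outPayload holds the same two bytes, and outTensors[0] is a 2x2 tensor
//   // of ones.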