#include <queue>

#include <ATen/Parallel.h>
#include <c10/core/Event.h>
#include <c10/util/DeadlockDetection.h>
#include <c10/util/irange.h>
#include <c10/util/thread_name.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>
#include <torch/csrc/autograd/input_buffer.h>
#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/engine/dist_engine.h>

namespace torch {
namespace distributed {
namespace autograd {

using torch::autograd::AccumulateGrad;
using torch::autograd::edge_list;
using torch::autograd::Engine;
using torch::autograd::GraphRoot;
using torch::autograd::GraphTask;
using torch::autograd::GraphTaskGuard;
using torch::autograd::InputBuffer;
using torch::autograd::Node;
using torch::autograd::NodeTask;
using torch::autograd::ReadyQueue;
using torch::autograd::validate_outputs;
using torch::autograd::variable_list;

static constexpr const char* kNumBackwardPasses = "num_current_backward_passes";
static constexpr const char* kNumAutogradContexts = "num_autograd_contexts";

// This hook does 3 things:
// 1. Call pre hooks of the original AccumulateGrad to modify the input grad.
// 2. Accumulate the grad to the RPC context.
// 3. Call post hooks of the original AccumulateGrad.
class DistAccumulateGradCaptureHook
    : public GraphTask::ExecInfo::Capture::GradCaptureHook {
 public:
  DistAccumulateGradCaptureHook(
      std::shared_ptr<AccumulateGrad> accumulateGrad,
      ContextPtr autogradContext)
      : accumulateGrad_(std::move(accumulateGrad)),
        autogradContext_(std::move(autogradContext)) {}

  at::Tensor operator()(const at::Tensor& grad) override {
    ThreadLocalDistAutogradContext contextGuard{ContextPtr(autogradContext_)};
    variable_list inputGrads = {grad};
    // It's intended that pre/post hooks are still called even if the grad is
    // undefined here.
    for (const auto& hook : accumulateGrad_->pre_hooks()) {
      inputGrads = (*hook)(inputGrads);
    }

    // It is possible that the grad is not defined since a separate
    // invocation of the autograd engine on the same node might actually
    // compute this gradient.
    if (inputGrads[0].defined()) {
      // There are 3 internal references to 'inputGrads[0]' at this moment:
      // 1. 'inputGrads[0]' in this function.
      // 2. 'graph_task->captured_vars_' on the callsite in the local engine.
      // 3. 'InputBuffer& inputs' on the callsite as the inputs of the
      //    function node.
      autogradContext_->accumulateGrad(
          accumulateGrad_->variable, inputGrads[0], 3 /* num_expected_refs */);
    }
    const variable_list kEmptyOutput;
    for (const auto& hook : accumulateGrad_->post_hooks()) {
      (*hook)(kEmptyOutput, inputGrads);
    }
    return inputGrads[0];
  }

 private:
  std::shared_ptr<AccumulateGrad> accumulateGrad_;
  ContextPtr autogradContext_;
};

void DistEngine::globalCpuThread(
    const std::shared_ptr<ReadyQueue>& ready_queue) {
  c10::setThreadName("pt_dist_engine");
  while (true) {
    NodeTask task = ready_queue->pop();
    if (task.isShutdownTask_) {
      // Need to shutdown this thread.
      C10_LOG_API_USAGE_ONCE("torch.autograd.thread_shutdown");
      break;
    }

    auto graphTask = task.base_.lock();
    if (graphTask == nullptr) {
      // GraphTask has expired, ignore and continue processing.
      continue;
    }

    // Launch the execution on a JIT thread.
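    // Handing the work off to a JIT thread (rather than running it inline)
    // keeps this single global CPU thread free to keep draining its ready
    // queue while the graph task executes.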
    at::launch([this,
                graphTask,
                graphRoot = task.fn_,
                variables =
                    InputBuffer::variables(std::move(task.inputs_))]() mutable {
      InputBuffer inputs(variables.size());
      for (const auto i : c10::irange(variables.size())) {
        inputs.add(i, std::move(variables[i]), std::nullopt, std::nullopt);
      }
      execute_graph_task_until_ready_queue_empty(
          /*node_task*/ NodeTask(graphTask, graphRoot, std::move(inputs)),
          /*incrementOutstandingTasks*/ false);
    });
  }
}

DistEngine::DistEngine()
    : initializedContextIds_(),
      engine_(Engine::get_default_engine()),
      global_cpu_ready_queue_(std::make_shared<ReadyQueue>()),
      global_cpu_thread_(
          &DistEngine::globalCpuThread,
          this,
          global_cpu_ready_queue_) {
  // Note [GPU to CPU continuations]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Initialize a single CPU thread to execute continuations from GPU
  // tasks. The multithreaded structure for the distributed engine works
  // well only for CPU tasks. If we have an order of tasks like
  // CPU->GPU->CPU, distributed autograd has no thread to execute the last
  // CPU task on. To fix this, we introduce a global CPU thread to handle
  // such situations and it will be responsible for executing these CPU
  // tasks. The CPU thread has its own ready_queue which is used as the
  // cpu_ready_queue for all GraphTasks for DistEngine. This ensures all GPU
  // to CPU continuations are enqueued on this thread. The global CPU thread
  // simply dequeues tasks from the global queue and calls
  // "execute_graph_task_until_ready_queue_empty" on a JIT thread to execute
  // the appropriate task.
  global_cpu_thread_.detach();
}

DistEngine::~DistEngine() {
  // Ensure we shutdown the CPU thread.
  TORCH_ASSERT_NO_GIL_WITHOUT_PYTHON_DEP();
  global_cpu_ready_queue_->pushShutdownTask();
  global_cpu_thread_.join();
}

DistEngine& DistEngine::getInstance() {
  // Leaky singleton to avoid module destructor race.
  static DistEngine* engine = new DistEngine();
  return *engine;
}

void DistEngine::validateRootsAndRetrieveEdges(
    const variable_list& roots,
    edge_list& rootEdges,
    variable_list& grads) {
  TORCH_CHECK(!roots.empty(), "No tensors provided for gradient computation.");
  TORCH_INTERNAL_ASSERT(rootEdges.empty());
  TORCH_INTERNAL_ASSERT(grads.empty());

  // Verify roots are all scalar and require gradients.
  for (const auto& root : roots) {
    TORCH_CHECK(root.requires_grad(), "requires_grad not set on root");
    TORCH_CHECK(
        root.numel() == 1,
        root.name(),
        " is not a scalar, all roots need to be scalar");
    TORCH_CHECK(
        root.grad_fn(),
        root.name(),
        " does not have a valid gradient function.");

    // Compute the root edges and generate the appropriate gradients.
    rootEdges.push_back(torch::autograd::impl::gradient_edge(root));
    grads.push_back(at::ones_like(root, LEGACY_CONTIGUOUS_MEMORY_FORMAT));
  }

  // Validate rootEdges and grads.
  validate_outputs(
      rootEdges, grads, [](const std::string& msg) { return msg; });
}

void DistEngine::computeDependencies(
    const ContextPtr& autogradContext,
    const edge_list& rootEdges,
    const variable_list& grads,
    const std::shared_ptr<Node>& graphRoot,
    edge_list& outputEdges,
    bool retainGraph) {
  TORCH_INTERNAL_ASSERT(graphRoot, "graphRoot is null!");

  // Store root nodes so we can traverse through the graph later
  // e.g., for get_current_graph_task_execution_order
  c10::SmallVector<Node*, 4> temp_roots{rootEdges.size()};
  for (const auto i : c10::irange(rootEdges.size())) {
    temp_roots[i] = rootEdges[i].function.get();
  }

  // Build the graph task and graph root.
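  // 'create_graph' is fixed to false below: the distributed engine only
  // computes first-order gradients and never records the backward pass itself
  // for further differentiation.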
  // NOTE: we don't need to build and pass a cpu_ready_queue to GraphTask
  // as we use execute_graph_task_until_ready_queue_empty, which will build
  // a separate ReadyQueue for each call.
  auto graphTask = std::make_shared<GraphTask>(
      /* keep_graph */ retainGraph,
      /* create_graph */ false,
      /* depth */ 0,
      /* cpu_ready_queue */ global_cpu_ready_queue_,
      /* graph_roots */ temp_roots,
      /* exit_on_error */ true);

  // Run BFS to traverse the graph locally. The roots of the graph are
  // GraphRoot and all send functions for this autograd context.
  std::unordered_set<Node*> seen;
  std::queue<Node*> queue;
  queue.push(static_cast<Node*>(graphRoot.get()));

  auto sendFunctions = autogradContext->sendFunctions();

  // Add all the send functions to the queue as roots.
  for (const auto& mapEntry : sendFunctions) {
    // Increment 'outstanding_tasks_' for GraphTask for each send_function
    // since we want the local autograd engine to wait for all of them.
    graphTask->outstanding_tasks_++;
    queue.push(mapEntry.second.get());
  }

  bool will_use_accelerator = false;

  edge_list recvBackwardEdges;
  // Traverse the graph.
  auto& dependencies = graphTask->dependencies_;
  while (!queue.empty()) {
    auto fn = queue.front();
    queue.pop();

    if (!will_use_accelerator) {
      will_use_accelerator = fn->stream().has_value();
    }

    for (const auto& edge : fn->next_edges()) {
      if (auto nextFn = edge.function.get()) {
        dependencies[nextFn] += 1;
        const bool wasInserted = seen.insert(nextFn).second;
        if (wasInserted) {
          // Seeing this function for the first time.
          queue.push(nextFn);

          if (nextFn->next_edges().empty()) {
            TORCH_INTERNAL_ASSERT(
                dynamic_cast<AccumulateGrad*>(nextFn) ||
                dynamic_cast<RecvRpcBackward*>(nextFn));
            // We have found a leaf node which should be either AccumulateGrad
            // or RecvRpcBackward. Record the function to ensure we don't
            // execute it and instead accumulate the grads on the autograd
            // context. These functions would be passed in as the 'outputs'
            // parameter of the vanilla autograd engine.

            // We don't accumulate any grads in the context for RecvRpcBackward.
            // RecvRpcBackward is added as an output edge to indicate it is a
            // leaf node and this helps in properly computing dependencies for
            // the local autograd graph. Putting RecvRpcBackward in
            // 'outputEdges' means that this function needs to be executed
            // (inline with our assumption for FAST mode that all send/recv
            // functions are valid in the backward pass), and as a result all
            // of its ancestors need to be executed as well.
            if (dynamic_cast<RecvRpcBackward*>(nextFn)) {
              recvBackwardEdges.emplace_back(edge);
            }
            outputEdges.emplace_back(edge);
          }
        }
      }
    }
  }

  if (will_use_accelerator) {
    // Collects current streams for CUDA/ROCM devices where this process has a
    // context, so graphTask::exec_post_processing can sync them with
    // leaf_streams.
    graphTask->stash_current_streams();
  }
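
  // As an illustration (node names hypothetical): for a local subgraph
  //   SendRpcBackward -> MulBackward0 -> AccumulateGrad
  //                                   -> RecvRpcBackward
  // 'outputEdges' now holds edges to both leaves (AccumulateGrad and
  // RecvRpcBackward), while 'recvBackwardEdges' holds only the edge to
  // RecvRpcBackward.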

  // Now let's compute which functions need to be executed. The algorithm is
  // as follows:
  // 1. Create a dummy GraphRoot which points to all 'send' functions for this
  //    context and the original graphRoot. Run 'init_to_execute' with the
  //    outputEdges and the dummy GraphRoot. This ensures we mark appropriate
  //    functions as needed if they are reachable only from a specific 'send'
  //    function locally and not necessarily from the provided roots.
  // 2. For all edges in 'outputEdges' which point to 'RecvRpcBackward', mark
  //    those functions as needed for execution. The reason for this is that
  //    'init_to_execute' will mark these as not needed. But 'RecvRpcBackward'
  //    is unique in the sense that we use it as a leaf node in the graph to
  //    compute needed execution accurately, but unlike AccumulateGrad, we do
  //    need to execute this function.
  if (!outputEdges.empty()) {
    // Compute 'needed execution' starting from all 'send' functions and the
    // original graphRoot.
    edge_list edges;
    // Create some dummy edges (input_nr not important for init_to_execute).
    for (const auto& mapEntry : sendFunctions) {
      edges.emplace_back(mapEntry.second, 0);
    }

    // Add the original graphRoot as an edge.
    edges.emplace_back(graphRoot, 0);

    // Create a dummy GraphRoot and run init_to_execute with it.
    GraphRoot dummyRoot(edges, {});
    graphTask->init_to_execute(
        dummyRoot, outputEdges, /*accumulate_grad=*/false, /*min_topo_nr=*/0);
    for (auto& mapEntry : graphTask->exec_info_) {
      auto& execInfo = mapEntry.second;
      if (!execInfo.captures_) {
        continue;
      }
      auto fn = mapEntry.first;
      // There may be nodes other than 'AccumulateGrad', e.g. RecvRPCBackward,
      // to be captured.
      if (auto accumulateGradFn = dynamic_cast<AccumulateGrad*>(fn)) {
        for (auto& capture : *execInfo.captures_) {
          // Capture hooks are technically deprecated, but as an exception
          // below is the single and only instance of capture hooks usage that
          // we support. See NOTE [Deprecated capture hooks] for more context.
          capture.DO_NOT_USE_DEPRECATED_register_capture_hook(
              std::make_unique<DistAccumulateGradCaptureHook>(
                  std::dynamic_pointer_cast<AccumulateGrad>(
                      accumulateGradFn->shared_from_this()),
                  autogradContext));
        }
      }
    }

    // Mark all 'RecvRPCBackward' as needing execution.
    for (const auto& recvBackwardEdge : recvBackwardEdges) {
      graphTask->exec_info_[recvBackwardEdge.function.get()].needed_ = true;
    }
  }

  // Set graph task owner in a single thread since concurrent access to
  // 'owner_' field is not permitted.
  graphTask->owner_ = torch::autograd::CPU_DEVICE;

  // Let autograd context take ownership of the GraphTask.
  autogradContext->setGraphTask(std::move(graphTask));
}

void DistEngine::execute_graph_task_until_ready_queue_empty(
    NodeTask&& node_task,
    bool incrementOutstandingTasks) {
  engine_.initialize_device_threads_pool();
  // Create a ready queue per call to traverse the graph_task from
  // root_to_execute. This allows concurrent execution of the same GraphTask
  // from different threads.
  std::shared_ptr<ReadyQueue> cpu_ready_queue = std::make_shared<ReadyQueue>();
  auto graph_task = node_task.base_.lock();
  if (graph_task == nullptr) {
    LOG(ERROR) << "GraphTask has expired for NodeTask: "
               << node_task.fn_->name() << ", skipping execution.";
    return;
  }

  cpu_ready_queue->push(std::move(node_task), incrementOutstandingTasks);

  torch::autograd::set_device(torch::autograd::CPU_DEVICE);
  while (!cpu_ready_queue->empty()) {
    std::shared_ptr<GraphTask> local_graph_task;
    {
      // Scope this block of execution since NodeTask is not needed after this
      // block and can be deallocated (release any references to grad tensors
      // as part of inputs_).
      NodeTask task = cpu_ready_queue->pop();
      if (!(local_graph_task = task.base_.lock())) {
        continue;
      }
      if (task.fn_ && !local_graph_task->has_error_.load()) {
        at::ThreadLocalStateGuard tls_guard(local_graph_task->thread_locals_);
        try {
          GraphTaskGuard guard(local_graph_task);
          engine_.evaluate_function(
              local_graph_task, task.fn_.get(), task.inputs_, cpu_ready_queue);
        } catch (std::exception& e) {
          engine_.thread_on_exception(local_graph_task, task.fn_, e);
          // break the loop in error so that we immediately stop the execution
          // of this GraphTask, mark it completed if necessary and return the
          // future with proper ErrorMessage
          break;
        }
      }
    }
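    // Each NodeTask processed here was accounted for in 'outstanding_tasks_'
    // (either when it was pushed onto a ready queue or, for send functions,
    // up front in computeDependencies); graph_task->completed() relies on
    // this counter draining back to zero.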
    // Decrement the outstanding task.
    --local_graph_task->outstanding_tasks_;
  }
  // Check if we've completed execution.
  if (graph_task->completed()) {
    // We don't need to explicitly notify the owner thread, since
    // 'mark_as_completed_and_run_post_processing' would mark the Future as
    // completed and this would notify the owner thread that the task has been
    // completed.
    graph_task->mark_as_completed_and_run_post_processing();
  }
}

c10::intrusive_ptr<c10::ivalue::Future> DistEngine::
    runEngineAndAccumulateGradients(
        const ContextPtr& autogradContext,
        const std::shared_ptr<Node>& graphRoot,
        const edge_list& outputEdges,
        bool incrementOutstandingTasks) {
  // Cleanup previous state for outstanding RPCs. Outstanding RPCs could be
  // lingering if we're running backward multiple times and some of the
  // passes ran into errors.
  autogradContext->clearOutstandingRpcs();
  auto graphTask = autogradContext->retrieveGraphTask();
  at::launch([this, graphTask, graphRoot, incrementOutstandingTasks]() {
    execute_graph_task_until_ready_queue_empty(
        /*node_task*/ NodeTask(graphTask, graphRoot, InputBuffer(0)),
        /*incrementOutstandingTasks*/ incrementOutstandingTasks);
  });
  // Use a reference here to avoid refcount bump on futureGrads.
  auto& futureGrads = graphTask->future_result_;

  // Build a future that waits for the callbacks to execute (since callbacks
  // execute after the original future is completed). This ensures we return a
  // future that waits for all gradient accumulation to finish.
  auto accumulateGradFuture =
      c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());

  futureGrads->addCallback([autogradContext, outputEdges, accumulateGradFuture](
                               c10::ivalue::Future& futureGrads) {
    if (futureGrads.hasError()) {
      // Don't accumulate gradients if we receive an error.
      // We must add the node information here since DistEngine::execute
      // waits on accumulateGradFuture and will throw an exception once we
      // set the error below.
      std::string errorMsg = c10::str(
          "Error on Node ",
          DistAutogradContainer::getInstance().getWorkerId(),
          ": ",
          futureGrads.tryRetrieveErrorMessage());
      accumulateGradFuture->setError(std::make_exception_ptr(
          c10::ivalue::Future::FutureError(std::move(errorMsg))));
      return;
    }

    try {
      const variable_list& grads = futureGrads.constValue().toTensorVector();
      TORCH_INTERNAL_ASSERT(grads.size() == outputEdges.size());
      accumulateGradFuture->markCompleted(c10::IValue());
    } catch (std::exception& e) {
      accumulateGradFuture->setErrorIfNeeded(std::current_exception());
    }
  });

  return accumulateGradFuture;
}

c10::intrusive_ptr<c10::ivalue::Future> DistEngine::executeSendFunctionAsync(
    const ContextPtr& autogradContext,
    const std::shared_ptr<SendRpcBackward>& sendFunction,
    bool retainGraph) {
  // Typically the local autograd engine ensures stream synchronizations
  // between nodes in the graph. However, for distributed autograd the
  // sendFunction inputs might have been retrieved over the wire on a separate
  // stream and the sendFunction itself runs on a different stream. As a
  // result, we need to manually synchronize those two streams here.
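  // Concretely: below we record an event on the current stream of each
  // gradient's device and make the send function's stream wait on that event,
  // so the backward kernels cannot run before the gradient tensors are ready.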
  const auto& send_backward_stream = sendFunction->stream();
  if (send_backward_stream) {
    for (const auto& grad : sendFunction->getGrads()) {
      const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::CUDA};
      const auto default_stream = guard.getStream(grad.device());
      if (send_backward_stream != default_stream) {
        auto event = c10::Event{c10::DeviceType::CUDA};
        event.record(default_stream);
        send_backward_stream->wait(event);
      }
    }
  }

  std::unique_lock<std::mutex> lock(initializedContextIdsLock_);
  if (initializedContextIds_.find(autogradContext->contextId()) ==
      initializedContextIds_.end()) {
    edge_list outputEdges;
    // Pass in a dummy graphRoot since all send functions are the roots.
    auto dummyRoot = std::make_shared<GraphRoot>(edge_list(), variable_list());
    computeDependencies(
        autogradContext, {}, {}, dummyRoot, outputEdges, retainGraph);

    // Mark the autograd context id as initialized and unlock.
    initializedContextIds_.insert(autogradContext->contextId());
    lock.unlock();

    // Enqueue the current send function.
    auto graphTask = autogradContext->retrieveGraphTask();
    // Run the autograd engine.
    auto accumulateGradFuture = runEngineAndAccumulateGradients(
        autogradContext,
        sendFunction,
        outputEdges,
        /*incrementOutstandingTasks=*/false);

    // Build the 'uber' future that waits for everything.
    auto callbackFuture =
        c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());

    accumulateGradFuture->addCallback(
        [autogradContext,
         callbackFuture](c10::ivalue::Future& accumulateGradFuture) {
          try {
            if (accumulateGradFuture.hasError()) {
              // Perform cleanup at the end of the backward pass (before we
              // mark the future as completed).
              DistEngine::getInstance().cleanupBackwardPass(autogradContext);

              // Skip any further processing on errors.
              callbackFuture->setError(accumulateGradFuture.exception_ptr());
              return;
            }

            // Wait for all RPCs after the autograd engine is done.
            auto rpcFuture =
                autogradContext->clearAndWaitForOutstandingRpcsAsync();
            rpcFuture->addCallback([callbackFuture, autogradContext](
                                       c10::ivalue::Future& rpcFuture) {
              try {
                // Perform cleanup at the end of the backward pass (before
                // we mark the future as completed).
                DistEngine::getInstance().cleanupBackwardPass(autogradContext);
              } catch (std::exception& e) {
                callbackFuture->setErrorIfNeeded(std::current_exception());
                return;
              }

              // Finally mark the 'uber' future as completed.
              if (!rpcFuture.hasError()) {
                callbackFuture->markCompleted(c10::IValue());
              } else {
                callbackFuture->setError(rpcFuture.exception_ptr());
              }
            });
          } catch (std::exception& e) {
            callbackFuture->setErrorIfNeeded(std::current_exception());
          }
        });

    // Return the future which waits for all async processing to be done.
    return callbackFuture;
  } else {
    lock.unlock();
    auto graphTask = autogradContext->retrieveGraphTask();
    at::launch([this, graphTask, sendFunction]() {
      execute_graph_task_until_ready_queue_empty(
          /*node_task*/ NodeTask(graphTask, sendFunction, InputBuffer(0)),
          /*incrementOutstandingTasks*/ false);
    });
    auto fut = c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());
    fut->markCompleted(c10::IValue());
    return fut;
  }
}
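
// DistEngine::execute is the entry point on the worker that initiates the
// backward pass (the caller of torch.distributed.autograd.backward). Other
// workers participating in the pass enter the engine through
// executeSendFunctionAsync above, when a gradient propagation RPC arrives for
// one of their 'send' functions.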
void DistEngine::execute(
    int64_t contextId,
    const variable_list& roots,
    bool retainGraph) {
  // Retrieve the context for the given context_id. This will throw if the
  // context_id is invalid.
  auto autogradContext =
      DistAutogradContainer::getInstance().retrieveContext(contextId);

  // Perform initial pre-processing.
  edge_list rootEdges;
  variable_list grads;
  validateRootsAndRetrieveEdges(roots, rootEdges, grads);

  std::shared_ptr<Node> graphRoot =
      std::make_shared<GraphRoot>(rootEdges, grads);
  edge_list outputEdges;
  // Compute dependencies locally, starting from all roots and all 'send'
  // functions.
  {
    std::lock_guard<std::mutex> guard(initializedContextIdsLock_);
    // Context should not have been initialized already.
    TORCH_INTERNAL_ASSERT(
        initializedContextIds_.find(autogradContext->contextId()) ==
        initializedContextIds_.end());

    computeDependencies(
        autogradContext, rootEdges, grads, graphRoot, outputEdges, retainGraph);

    // Mark the autograd context id as initialized.
    initializedContextIds_.insert(autogradContext->contextId());
  }

  BackwardPassCleanupGuard guard(autogradContext);

  // This needs to be blocking and as a result we wait for the future to
  // complete.
  runEngineAndAccumulateGradients(autogradContext, graphRoot, outputEdges)
      ->waitAndThrow();

  // Wait for all of the outstanding rpcs to complete.
  autogradContext->clearAndWaitForOutstandingRpcsAsync()->waitAndThrow();
}

void DistEngine::cleanupBackwardPass(const ContextPtr& autogradContext) {
  // Validate only the GraphTask is holding a reference to the Future
  // which holds gradients for the backward pass. This ensures that
  // after 'resetGraphTask' is called below, there are no remaining
  // references left to the gradients for the backward pass.
  //
  // This ensures our 'use_count' checks in
  // AccumulateGrad::accumulateGrad are correct and we're
  // not leaking any references to the gradients anywhere else.
  const auto& futureGrads =
      autogradContext->retrieveGraphTask()->future_result_;
  TORCH_INTERNAL_ASSERT(futureGrads.use_count() == 1);

  // Reset the graph task once we're done with all processing.
  autogradContext->resetGraphTask();

  // Clear any outstanding rpcs.
  autogradContext->clearOutstandingRpcs();

  // Clear the context id once we're done with the autograd engine
  // processing.
  std::lock_guard<std::mutex> guard(initializedContextIdsLock_);
  initializedContextIds_.erase(autogradContext->contextId());
}

size_t DistEngine::numBackwardPasses() const {
  std::lock_guard<std::mutex> guard(initializedContextIdsLock_);
  return initializedContextIds_.size();
}

std::unordered_map<std::string, int> DistEngine::getDebugInfo() const {
  std::unordered_map<std::string, int> debugInfo;
  debugInfo[kNumBackwardPasses] = numBackwardPasses();
  debugInfo[kNumAutogradContexts] =
      DistAutogradContainer::getInstance().numAutogradContexts();
  return debugInfo;
}

} // namespace autograd
} // namespace distributed
} // namespace torch
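
// Summary of a distributed backward pass as implemented above:
// 1. DistEngine::execute validates the roots, runs computeDependencies and
//    then runEngineAndAccumulateGradients, blocking until all gradients and
//    outstanding RPCs are done.
// 2. Workers receiving gradient propagation RPCs enter via
//    executeSendFunctionAsync, which (for the first send function of a
//    context) returns a future completed only once local execution,
//    outstanding RPCs and cleanup have finished.
// 3. Leaf gradients are captured by DistAccumulateGradCaptureHook and
//    accumulated into the DistAutogradContext (later retrieved via
//    dist_autograd.get_gradients) rather than into the tensors' .grad fields.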