#include #include #include namespace c10d { Work::Work( int rank, OpType opType, const char* profilingTitle, const std::optional>& inputTensors) : rank_(rank), opType_(opType) { if (profilingTitle != nullptr) { auto recordingFunction = std::make_shared(at::RecordScope::USER_SCOPE); if (recordingFunction->isActive()) { // Work events follow a future like pattern and can potentially be marked // as complete by different threads, so explicitly set as async event. recordingFunction->_setAsync(); // Passing input tensor to recordFunction allows for shape information in // profiling output. std::vector inputs; if (inputTensors) { inputs.reserve(inputTensors->size()); for (const auto& tensor : *inputTensors) { inputs.emplace_back(tensor); } } recordingFunction->before( profilingTitle, c10::ArrayRef(inputs.data(), inputs.size())); std::function end_handler = [recordingFunction]() { recordingFunction->end(); }; recordFunctionEndCallback_ = at::wrapPropagateTLSState(end_handler); } } } OpType Work::retrieveOpType() const { return opType_; } Work::~Work() = default; bool Work::isCompleted() { std::lock_guard lock(mutex_); return completed_; } bool Work::isSuccess() const { std::lock_guard lock(mutex_); return !exception_; } std::exception_ptr Work::exception() const { std::lock_guard lock(mutex_); return exception_; } int Work::sourceRank() const { TORCH_CHECK( false, "sourceRank() may only be called on work objects " "that correspond to a recv or recv-from-any call."); } std::vector Work::result() { TORCH_CHECK(false, "result() not implemented."); } void Work::synchronize() {} bool Work::wait(std::chrono::milliseconds timeout) { std::unique_lock lock(mutex_); if (timeout == kNoTimeout) { // This waits without a timeout. cv_.wait(lock, [&] { return completed_; }); } else { // Waits for the user-provided timeout. cv_.wait_for(lock, timeout, [&] { return completed_; }); if (!completed_) { // Throw exception if the wait operation timed out and the work was not // completed. TORCH_CHECK(false, "Operation timed out!"); } } if (exception_) { std::rethrow_exception(exception_); } synchronize(); // Always return true, because abort API is not implemented. return true; } void Work::abort() { TORCH_CHECK(false, "Work::abort not implemented."); } c10::intrusive_ptr Work::getFuture() { TORCH_CHECK(false, "Work::getFuture not implemented.") } void Work::finish(std::exception_ptr exception) { std::unique_lock lock(mutex_); completed_ = true; exception_ = std::move(exception); if (recordFunctionEndCallback_) { recordFunctionEndCallback_(); recordFunctionEndCallback_ = nullptr; } lock.unlock(); cv_.notify_all(); } void Work::finishAndThrow(std::exception_ptr exception) { std::unique_lock lock(mutex_); completed_ = true; exception_ = std::move(exception); if (recordFunctionEndCallback_) { recordFunctionEndCallback_(); recordFunctionEndCallback_ = nullptr; } if (exception_) { std::rethrow_exception(exception_); } } float Work::getDuration() const { TORCH_CHECK(false, "This Backend doesn't support getDuration."); } uint64_t Work::getSequencenumber() const { TORCH_CHECK(false, "This Backend doesn't support getSequencenumber."); } class FutureWrappingWork : public Work { public: FutureWrappingWork(c10::intrusive_ptr fut) : Work(), _fut(std::move(fut)) {} ~FutureWrappingWork() override = default; bool isCompleted() override { return _fut->completed(); } bool isSuccess() const override { return _fut->hasValue(); } std::exception_ptr exception() const override { return _fut->exception_ptr(); } int sourceRank() const override { TORCH_CHECK(false, "FutureWrappingWork::sourceRank() not implemented"); } std::vector result() override { return _fut->value().toPyObjectHolder()->extractTensors(); } bool wait(std::chrono::milliseconds timeout) override { // FIXME TORCH_CHECK( timeout == kNoTimeout, "FutureWrappingWork::wait() with finite timeout not implemented"); _fut->wait(); return true; } void abort() override { TORCH_CHECK(false, "FutureWrappingWork::abort() not implemented"); } c10::intrusive_ptr getFuture() override { return _fut; } private: c10::intrusive_ptr _fut; }; c10::intrusive_ptr Work::create_from_future( const c10::intrusive_ptr& future) { return c10::make_intrusive(future); } } // namespace c10d