#ifdef USE_C10D_UCC

#include <torch/csrc/distributed/c10d/ProcessGroupUCC.hpp>
#include <torch/csrc/distributed/c10d/UCCTracing.hpp>
#include <torch/csrc/distributed/c10d/UCCUtils.hpp>

#include <list>
#include <map>
#include <memory>
#include <unordered_map>
#include <unordered_set>

namespace c10d {

namespace {

const std::map<c10::DeviceType, ucc_memory_type_t> ucc_mtype_map = {
    {c10::kCPU, UCC_MEMORY_TYPE_HOST},
    {c10::kCUDA, UCC_MEMORY_TYPE_CUDA},
};

ucc_memory_type_t to_ucc_memType(c10::DeviceType _c10_type) {
  if (ucc_mtype_map.find(_c10_type) != ucc_mtype_map.end())
    return ucc_mtype_map.at(_c10_type);
  else
    return UCC_MEMORY_TYPE_UNKNOWN;
}

const std::map<at::ScalarType, ucc_datatype_t> ucc_dtype_map = {
    {at::kByte, UCC_DT_UINT8},
    {at::kChar, UCC_DT_INT8},
    {at::kHalf, UCC_DT_FLOAT16},
    {at::kBFloat16, UCC_DT_BFLOAT16},
    {at::kDouble, UCC_DT_FLOAT64},
    {at::kFloat, UCC_DT_FLOAT32},
    {at::kInt, UCC_DT_INT32},
    {at::kLong, UCC_DT_INT64},
    {at::kBool, UCC_DT_UINT8},
};

ucc_datatype_t to_ucc_dType(at::Tensor _tensor) {
  if (_tensor.scalar_type() == at::kBool && _tensor.element_size() != 1) {
    TORCH_CHECK(
        false, "Size of Boolean type larger than 1 is not supported in UCC");
  }
  try {
    return ucc_dtype_map.at(_tensor.scalar_type());
  } catch (const std::out_of_range&) {
    TORCH_CHECK(false, "Not supported data type for UCC");
  }
}

const std::map<ReduceOp::RedOpType, ucc_reduction_op_t> ucc_op_map = {
    {ReduceOp::SUM, UCC_OP_SUM},
    {ReduceOp::PRODUCT, UCC_OP_PROD},
    {ReduceOp::MIN, UCC_OP_MIN},
    {ReduceOp::MAX, UCC_OP_MAX},
    {ReduceOp::BAND, UCC_OP_BAND},
    {ReduceOp::BOR, UCC_OP_BOR},
    {ReduceOp::BXOR, UCC_OP_BXOR},
    {ReduceOp::AVG, UCC_OP_AVG},
};

ucc_reduction_op_t to_ucc_reduceOp(
    const ReduceOp _op,
    const at::ScalarType _dt) {
  if (_dt == at::kBool) {
    if (_op == ReduceOp::SUM) {
      // bitwise or
      return UCC_OP_MAX;
    } else if (_op == ReduceOp::PRODUCT) {
      // bitwise and
      return UCC_OP_MIN;
    } else if (_op == ReduceOp::AVG) {
      TORCH_CHECK(false, "Cannot use ReduceOp.AVG with boolean inputs");
    }
  }
  try {
    return ucc_op_map.at(_op);
  } catch (const std::out_of_range&) {
    TORCH_CHECK(false, "Not supported ReduceOp for UCC");
  }
}

struct torch_ucc_config_t {
  c10::once_flag flag;
  std::array<bool, 32> blocking_wait;
  bool enable_comms_logger;
  bool use_future;
  // Sharing UCC communicator among multiple PGs to save resource.
  bool shared_comm;
  // Using allgatherv to achieve allgather, without flattening the list of
  // (potentially non-contiguous) tensors.
  bool use_allgatherv;
  bool enable_health_check;
} torch_ucc_config;

std::unordered_map<std::string, std::string> torch_ucc_envs_map = {
    // TORCH_UCC_BLOCKING_WAIT allowed syntax:
    // - TORCH_UCC_BLOCKING_WAIT=none --> blocking wait completely disabled
    // - TORCH_UCC_BLOCKING_WAIT=all --> blocking wait completely enabled
    // - TORCH_UCC_BLOCKING_WAIT=allreduce,send,recv --> blocking wait enabled
    //   on selected operations
    // Supported operations:
    // [allgather,allgather_base,allreduce,alltoall,broadcast,
    //  gather,reduce,reduce_scatter,scatter,send,recv]
    {"TORCH_UCC_BLOCKING_WAIT", "none"},
    {"TORCH_UCC_USE_FUTURE", "1"},
    {"TORCH_UCC_PROFILING_ENABLE", "0"},
    {"TORCH_UCC_SHARED_COMM", "1"},
    {"TORCH_UCC_USE_ALLGATHERV", "0"},
    {"TORCH_UCC_ENABLE_HEALTH_CHECK", "0"},
    {"TORCH_UCC_ENABLE_COMMS_LOGGER", "0"},
};

std::vector<OpType> parse_blocking_wait(std::string op_list_string) {
  const static std::unordered_map<std::string, OpType> str2op = {
      {"allgather", OpType::ALLGATHER},
      {"allgather_base", OpType::_ALLGATHER_BASE},
      {"allreduce", OpType::ALLREDUCE},
      {"alltoall_base", OpType::ALLTOALL_BASE},
      {"broadcast", OpType::BROADCAST},
      {"gather", OpType::GATHER},
      {"reduce", OpType::REDUCE},
      {"reduce_scatter", OpType::REDUCE_SCATTER},
      {"scatter", OpType::SCATTER},
      {"send", OpType::SEND},
      {"recv", OpType::RECV},
  };
  auto op_list = parse_list(op_list_string);
  if (op_list == std::vector<std::string>{"none"}) {
    return {};
  }
  std::vector<OpType> result;
  if (op_list == std::vector<std::string>{"all"}) {
    for (auto entry : str2op) {
      result.push_back(entry.second);
    }
  } else {
    for (auto op_string : op_list) {
      result.push_back(str2op.at(op_string));
    }
  }
  return result;
}

} // namespace

void read_config() {
  // default configuration
  torch_ucc_config.blocking_wait.fill(false);
  torch_ucc_config.use_future = true;
  torch_ucc_config.shared_comm = false;
  torch_ucc_config.use_allgatherv = false;
  torch_ucc_config.enable_health_check = false;
  torch_ucc_config.enable_comms_logger = false;

  // read all torch_ucc env. variables and update the map
  char* env;
  for (auto& torch_ucc_env : torch_ucc_envs_map) {
    env = std::getenv(torch_ucc_env.first.c_str());
    if (env) {
      torch_ucc_envs_map[torch_ucc_env.first] = std::string(env);
    }
  }

  auto blocking_wait_str = torch_ucc_envs_map.at("TORCH_UCC_BLOCKING_WAIT");
  for (auto op : parse_blocking_wait(blocking_wait_str)) {
    torch_ucc_config.blocking_wait[(std::uint8_t)op] = true;
  }
  // barrier is always blocking
  torch_ucc_config.blocking_wait[(std::uint8_t)OpType::BARRIER] = true;

  torch_ucc_config.use_future =
      std::stoi(torch_ucc_envs_map.at("TORCH_UCC_USE_FUTURE"));
  torch_ucc_config.shared_comm =
      std::stoi(torch_ucc_envs_map.at("TORCH_UCC_SHARED_COMM"));
  torch_ucc_config.use_allgatherv =
      std::stoi(torch_ucc_envs_map.at("TORCH_UCC_USE_ALLGATHERV"));
  torch_ucc_config.enable_health_check =
      std::stoi(torch_ucc_envs_map.at("TORCH_UCC_ENABLE_HEALTH_CHECK"));
  torch_ucc_config.enable_comms_logger =
      std::stoi(torch_ucc_envs_map.at("TORCH_UCC_ENABLE_COMMS_LOGGER"));
}

void check_device(c10::Device dev1, c10::Device dev2) {
  if (dev1.is_cuda() && dev2.is_cuda() && dev1 != dev2) {
    throw std::invalid_argument("ProcessGroupUCC multidevice is not supported");
  }
}

void check_tensor(const std::vector<at::Tensor>& tensors) {
  if (tensors.size() != 1) {
    throw std::invalid_argument(
        "ProcessGroupUCC takes 1 tensor. Got " + std::to_string(tensors.size()) +
        ". ");
  }
"); } if (!tensors[0].is_contiguous()) { throw std::invalid_argument( "ProcessGroupUCC input tensor has to be contiguous"); } if (tensors[0].is_sparse()) { throw std::invalid_argument("ProcessGroupUCC input tensor has to be dense"); } // TODO: check cuda case } ProcessGroupUCC::WorkUCC::~WorkUCC() { #ifdef USE_CUDA if (fence && ep) { std::lock_guard lock(ep->event_pool_mutex); ep->event_pool.push(std::move(fence)); } #endif } void ProcessGroupUCC::WorkUCC::setException() { if (exception() || !entry_) { return; } exception_ = entry_->eptr_; } void ProcessGroupUCC::WorkUCC::setAndThrowException() { setException(); if (exception()) { std::rethrow_exception(exception()); } } bool ProcessGroupUCC::WorkUCC::isCompleted() { if (!entry_) { return true; } setException(); // status_ <= 0 to avoid listing all possible status codes. The main thread // needs to be unblocked when UCC (in progress thread) returns success (== 0) // or any error code (< 0). return exception() || entry_->status_ <= 0; } bool ProcessGroupUCC::WorkUCC::isSuccess() const { if (!entry_) { return true; } return !exception() && entry_->status_ == 0; } bool ProcessGroupUCC::WorkUCC::wait(std::chrono::milliseconds /* unused */) { if (torch_ucc_config.enable_comms_logger && logger_) { logger_->trace_generator->recordComms("wait", (uintptr_t)this, rank_); } #ifdef USE_CUDA if (fence && !torch_ucc_config.blocking_wait[(int)opType_]) { // block user stream setAndThrowException(); fence->block(at::cuda::getCurrentCUDAStream()); return true; } #endif // wait for complete. For blocking case, the main thread will be blocked in // this loop until the progress thread changes the status of this request. // If timeout occurs, UCC will return UCC_ERR_TIMEOUT as the status. The // main thread will throw out the exception then. There is no "abort" // function in UCC currently. while (!isCompleted()) ; setAndThrowException(); // manually call profiling end callbacks if they are set, // since progress thread does not own WorkUCC if (Work::recordFunctionEndCallback_) { Work::recordFunctionEndCallback_(); Work::recordFunctionEndCallback_ = nullptr; } return true; } c10::intrusive_ptr ProcessGroupUCC::WorkUCC::getFuture() { return future_; } int ProcessGroupUCC::WorkUCC::sourceRank() const { if (opType_ != OpType::RECV && opType_ != OpType::RECVANYSOURCE) { // Throw an error return Work::sourceRank(); } return sourceRank_; } std::vector ProcessGroupUCC::WorkUCC::result() { return *outputs_; } void ProcessGroupUCC::ProgressEntry::finalize(std::exception_ptr eptr) { ucc_status_t status = UCC_OK; if (request_ != nullptr) { status = request_->status; comm_->free_request(request_); } if (eptr) { eptr_ = eptr; } else { status_ = status; } if (future_) { if (eptr) { future_->setError(eptr); } else { future_->markCompleted( c10::IValue(data ? data->dst : std::vector())); } } } Comm::Comm( const c10::intrusive_ptr& logger_, std::shared_ptr oob_, c10::Device dev, bool is_health_check) : logger(logger_), oob(oob_), ucc_comm(oob, logger), finalize_phase( is_health_check ? 
TORCH_UCC_HEALTH_CHECK : TORCH_UCC_FINALIZE), cuda_device_index(TORCH_UCC_DEVICE_NOT_SET) { if (dev.is_cuda()) { cuda_device_index = dev.index(); } stop_progress_loop = false; collective_inprogress = false; progress_thread = std::thread(&Comm::progress_loop, this); #ifdef _GNU_SOURCE pthread_setname_np(progress_thread.native_handle(), "ucc-progress"); #endif } Comm::~Comm() { std::unique_lock lock(mutex); queue_consume_cv.wait( lock, [&] { return progress_queue.empty() && !collective_inprogress; }); stop_progress_loop = true; lock.unlock(); queue_produce_cv.notify_all(); progress_thread.join(); } std::shared_ptr Comm::get_comm( uint32_t& id, c10::Device dev, std::shared_ptr oob, const c10::intrusive_ptr& logger, bool is_health_check) { static std::mutex m; static std::weak_ptr comm; static uint32_t comm_id; std::lock_guard lock(m); id = comm_id; std::string group_id = "group_id"; if (is_health_check) { group_id = c10::str(dev.type()) + "/" + group_id; } std::vector remote_comm_id; oob->store->deleteKey(group_id + std::to_string(0)); if (oob->rank != 0) { std::vector val = std::vector( reinterpret_cast(&id), reinterpret_cast(&id) + sizeof(id)); oob->store->set(group_id + std::to_string(oob->rank), val); } else { for (int i = 1; i < oob->size; i++) { remote_comm_id = oob->store->get(group_id + std::to_string(i)); oob->store->deleteKey(group_id + std::to_string(i)); // Find the highest id. id = std::max(id, *(reinterpret_cast(remote_comm_id.data()))); } std::vector val = std::vector( reinterpret_cast(&id), reinterpret_cast(&id) + sizeof(id)); oob->store->set(group_id + std::to_string(oob->rank), val); } remote_comm_id = oob->store->get(group_id + std::to_string(0)); oob->comm_id = *(reinterpret_cast(remote_comm_id.data())); // Prepare comm_id (static variable) to the next id. comm_id = oob->comm_id + 1; if (torch_ucc_config.shared_comm) { std::shared_ptr shared_comm = comm.lock(); if (!shared_comm) { shared_comm = std::make_shared(logger, oob, dev, is_health_check); comm = shared_comm; } else { if (dev.is_cuda() && !is_health_check) { if ((shared_comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) && (shared_comm->cuda_device_index != dev.index())) { TORCH_UCC_LOG_ERROR( is_health_check ? 
TORCH_UCC_HEALTH_CHECK : TORCH_UCC_INIT, "ucc communicator was initialized with different cuda device," "multi device is not supported"); throw std::invalid_argument(ucc_status_string(UCC_ERR_NOT_SUPPORTED)); } shared_comm->cuda_device_index = dev.index(); } } return shared_comm; } else { return std::make_shared(logger, oob, dev, is_health_check); } } void Comm::ucc_create_team( ucc_team_h& team, std::shared_ptr oob) { ucc_status_t st; ucc_team_params_t team_params; team_params.mask = UCC_TEAM_PARAM_FIELD_EP | UCC_TEAM_PARAM_FIELD_EP_RANGE | UCC_TEAM_PARAM_FIELD_OOB; team_params.oob.allgather = oob_allgather; team_params.oob.req_test = oob_allgather_test; team_params.oob.req_free = oob_allgather_free; team_params.oob.coll_info = oob.get(); team_params.oob.n_oob_eps = oob->size; team_params.oob.oob_ep = oob->rank; team_params.ep = oob->rank; team_params.ep_range = UCC_COLLECTIVE_EP_RANGE_CONTIG; TORCH_UCC_CHECK( ucc_team_create_post(&ucc_comm.context, 1, &team_params, &team), "failed to post team create"); do { st = ucc_team_create_test(team); ucc_context_progress(ucc_comm.context); } while (st == UCC_INPROGRESS); TORCH_UCC_CHECK(st, "failed to create UCC team"); } void Comm::ucc_destroy_team(ucc_team_h& team) { std::unique_lock lock(mutex); queue_consume_cv.wait( lock, [&] { return progress_queue.empty() && !collective_inprogress; }); ucc_status_t status; while (UCC_INPROGRESS == (status = ucc_team_destroy(team))) { if (UCC_OK != status) { TORCH_UCC_LOG_ERROR( finalize_phase, c10::str("ucc team destroy error: ", ucc_status_string(status))); break; } } lock.unlock(); } void Comm::enqueue_collective( std::unique_ptr data, c10::intrusive_ptr work, ucc_coll_args_t& coll, ucc_team_h team) { ucc_coll_req_h request; TORCH_UCC_CHECK( ucc_collective_init(&coll, &request, team), "failed to init collective"); TORCH_UCC_CHECK_REQUEST( request, ucc_collective_post(request), "failed to post collective"); auto entry = std::make_shared(&ucc_comm, request); entry->data = std::move(data); entry->future_ = work->getFuture(); work->entry_ = entry; std::unique_lock lock(mutex); progress_queue.push_back(entry); lock.unlock(); queue_produce_cv.notify_one(); } #ifdef USE_CUDA void Comm::enqueue_cuda_collective( std::unique_ptr data, c10::intrusive_ptr work, ucc_coll_args_t& coll, ucc_team_h team, ucc_ee_h ee) { ucc_coll_req_h request; TORCH_UCC_CHECK( ucc_collective_init(&coll, &request, team), "failed to init cuda collective"); ucc_ev_t comp_ev, *post_ev; comp_ev.ev_type = UCC_EVENT_COMPUTE_COMPLETE; comp_ev.ev_context = nullptr; comp_ev.ev_context_size = 0; comp_ev.req = request; TORCH_UCC_CHECK_REQUEST( request, ucc_collective_triggered_post(ee, &comp_ev), "failed to post triggered collective"); ucc_status_t st = ucc_ee_get_event(ee, &post_ev); TORCH_CHECK(st == UCC_OK && post_ev->ev_type == UCC_EVENT_COLLECTIVE_POST); ucc_ee_ack_event(ee, post_ev); auto entry = std::make_shared(&ucc_comm, request); entry->data = std::move(data); work->entry_ = entry; std::unique_lock lock(mutex); progress_queue.push_back(entry); lock.unlock(); queue_produce_cv.notify_one(); } #endif void Comm::progress_loop() { std::unique_lock lock(mutex); #ifdef USE_CUDA bool device_set = false; #endif while (!stop_progress_loop) { if (progress_queue.empty()) { queue_produce_cv.wait(lock); continue; } collective_inprogress = true; auto work = progress_queue.front(); progress_queue.pop_front(); lock.unlock(); #ifdef USE_CUDA if ((!device_set) && (cuda_device_index != TORCH_UCC_DEVICE_NOT_SET)) { c10::cuda::set_device(cuda_device_index); 
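// The progress thread is pinned to the communicator's CUDA device exactly once;
// the block below also makes sure a CUDA driver context is current on this
// thread (retaining the device's primary context if none is set), since UCC may
// issue CUDA work while progressing collectives from here.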
CUcontext pctx = nullptr; at::globalContext().getNVRTC().cuCtxGetCurrent(&pctx); if (C10_UNLIKELY(!pctx)) { at::globalContext().getNVRTC().cuDevicePrimaryCtxRetain( &pctx, cuda_device_index); at::globalContext().getNVRTC().cuCtxSetCurrent(pctx); } device_set = true; } #endif std::exception_ptr eptr; try { while (work->request_->status > 0) { ucc_comm.progress(); } if (work->request_->status < 0) { eptr = std::make_exception_ptr( std::runtime_error(ucc_status_string(work->request_->status))); std::string err_log = c10::str( "Failed to progress communication", // TODO: report exact op type or // id? ucc_status_string(work->request_->status)); TORCH_UCC_LOG_ERROR(TORCH_UCC_COLL_PROGRESS, err_log); } } catch (...) { eptr = std::current_exception(); } work->finalize(eptr); work = nullptr; collective_inprogress = false; queue_consume_cv.notify_one(); lock.lock(); } } ProcessGroupUCC::ProcessGroupUCC( const c10::intrusive_ptr& store, int rank, int size, std::chrono::duration timeout) : Backend(rank, size), timeout_(timeout) { c10::call_once(torch_ucc_config.flag, read_config); oob = std::make_shared(); oob->rank = rank; oob->size = size; oob->store = store; comm = nullptr; cuda_ee = nullptr; static uint32_t id = 0; uint32_t pg_id = id++; logger = c10::make_intrusive( c10::str("[Rank ", rank_, "]", "[ProcessGroupUCC-", pg_id, "]"), TORCH_UCC_INIT); TORCH_UCC_LOG_INFO( TORCH_UCC_INIT, c10::str( "Created ProcessGroupUCC with ", size, " ranks, with timeout ", timeout_.count(), " secs")); std::string envs = ""; for (auto& torch_ucc_env : torch_ucc_envs_map) { envs += ("\n\t" + torch_ucc_env.first + "=" + torch_ucc_env.second); } TORCH_UCC_LOG_INFO( TORCH_UCC_INIT, c10::str( "Successfully read and set ProcessGroupUCC env. variables as followings", envs)); if (torch_ucc_config.enable_health_check) { // Perform health check by initializing dummy communicators and destroying // them. This will help indicate any UCC/UCX-related issues prior to the // first collective. Run it in a separate thread and wait on CV to handle // timeouts so that if there are hangs, the main thread can still run // correctly. runHealthCheck(); } if (torch_ucc_config.enable_comms_logger) { logger->initCommsTracer(); } } ProcessGroupUCC::~ProcessGroupUCC() { if (torch_ucc_config.enable_comms_logger) { logger->flushComms(this->getRank(), this->getSize()); } if (comm) { logger->setPhase(TORCH_UCC_FINALIZE); comm->ucc_destroy_team(team); TORCH_UCC_LOG_INFO( TORCH_UCC_FINALIZE, "Successfully destroyed UCC library"); try { if (cuda_ee) { ucc_ee_destroy(cuda_ee); ucc_ee_destroy(cuda_ee_p2p[0]); ucc_ee_destroy(cuda_ee_p2p[1]); } } catch (std::exception& ex) { TORCH_UCC_LOG_INFO( TORCH_UCC_FINALIZE, c10::str( "(~ProcessGroupUCC) Caught error in Store Operation .. ", "[", ex.what(), "]")); } comm = nullptr; } } #ifdef USE_CUDA // Return CUDA device with ordinal given by input rank. c10::Device getCUDADeviceForRank(int rank) { TORCH_CHECK(rank >= 0, "Invalid rank ", rank); auto numGPUs = at::cuda::getNumGPUs(); auto deviceIdx = static_cast(rank % numGPUs); return c10::Device(c10::DeviceType::CUDA, deviceIdx); } #endif void ProcessGroupUCC::runHealthCheck() { // Run health check in a separate thread and wait on CV to handle timeouts. // This design allows us to handle hangs. // When size_ is 1, there is no need to do any communication at all. 
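// The health check below creates a throwaway UCC communicator and team for each
// candidate device, destroys them again, and signals the waiting main thread
// through a condition variable; a timeout or exception is surfaced as a failed
// check before the first real collective runs.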
if (size_ == 1) return; struct HealthCheckData { std::mutex healthCheckMutex; std::condition_variable healthCheckCv; bool uccHealthCheckSuccess = false; std::exception_ptr healthCheckException; } healthCheckData; auto t = std::thread([&healthCheckData, this]() { std::list devices{c10::kCPU}; #ifdef USE_CUDA c10::cuda::OptionalCUDAGuard gpuGuard; if (at::cuda::is_available()) { devices.emplace_front(getCUDADeviceForRank(rank_)); } #endif for (auto device : devices) { bool is_last_device = (device == devices.back()); try { auto oob = std::make_shared(); oob->rank = this->oob->rank; oob->size = this->oob->size; oob->store = this->oob->store; ucc_team_h team = nullptr; uint32_t comm_id; #ifdef USE_CUDA if (device.is_cuda()) { gpuGuard.set_index(device.index()); } #endif auto comm = Comm::get_comm(comm_id, device, oob, logger, true); comm->ucc_create_team(team, oob); comm->ucc_destroy_team(team); TORCH_UCC_LOG_INFO( TORCH_UCC_HEALTH_CHECK, c10::str( "UCC library health check succeed for device ", c10::DeviceTypeName(device.type()))); // Mark ucc health check as complete. if (is_last_device) { std::lock_guard lk(healthCheckData.healthCheckMutex); healthCheckData.uccHealthCheckSuccess = true; } comm = nullptr; oob = nullptr; // Notify main thread the health check is complete. if (is_last_device) { healthCheckData.healthCheckCv.notify_one(); } } catch (const std::exception&) { // Populate exception ptr. healthCheckData.healthCheckException = std::current_exception(); // Unblock waiting main thread which will report exception. healthCheckData.healthCheckCv.notify_one(); } // Unknown exceptions will just cause the program to terminate. } }); // We don't need to join the thread, just need to verify health check via the // CV. Hence we detach the thread here. t.detach(); // NOLINT TORCH_UCC_LOG_INFO( TORCH_UCC_HEALTH_CHECK, c10::str( "will wait up to ", timeout_.count(), " msec for UCC health check to complete.")); std::unique_lock lock(healthCheckData.healthCheckMutex); healthCheckData.healthCheckCv.wait_for(lock, timeout_, [&healthCheckData]() { return healthCheckData.uccHealthCheckSuccess; }); if (healthCheckData.healthCheckException) { std::rethrow_exception(healthCheckData.healthCheckException); } // If there is no exception, the likely culprit is a timeout/hang TORCH_CHECK( healthCheckData.uccHealthCheckSuccess, "ProcessGroupUCC: Health check failure: Failed to initialize UCC on rank ", rank_); } void ProcessGroupUCC::set_timeout(ucc_coll_args_t& args) { args.mask |= UCC_COLL_ARGS_FIELD_FLAGS; args.flags |= UCC_COLL_ARGS_FLAG_TIMEOUT; args.timeout = timeout_.count(); } #ifdef USE_CUDA std::unique_ptr ProcessGroupUCC::getPooledEvent() { std::unique_ptr ev; std::lock_guard lock(ep.event_pool_mutex); if (ep.event_pool.empty()) { ev = std::make_unique(); } else { ev = std::move(ep.event_pool.front()); ep.event_pool.pop(); } return ev; } #endif template c10::intrusive_ptr ProcessGroupUCC::collective_post( OpType opType, PreProcess preproc, PostProcess postproc, ucc_coll_args_t& coll, std::unique_ptr data, c10::Device dev, std::vector& inputTensors, std::vector& outputTensors, const char* prof_title) { seq_++; set_timeout(coll); auto work = c10::make_intrusive( opType, seq_, prof_title, inputTensors, logger); if (opType == OpType::RECV) { work->sourceRank_ = coll.root; } RECORD_COMMS_TRACE( logger->trace_generator, work, opType, this->getRank(), this->getSize(), inputTensors, outputTensors); // Store references to outputs to be used by result work->outputs_ = std::make_shared>(outputTensors); switch 
(dev.type()) { case c10::DeviceType::CPU: { if (torch_ucc_config.use_future) { work->future_ = c10::make_intrusive( c10::ListType::create(c10::TensorType::get())); } preproc(); comm->enqueue_collective(std::move(data), work, coll, team); postproc(); return work; } #ifdef USE_CUDA case c10::DeviceType::CUDA: { auto cuda_ev = getPooledEvent(); at::cuda::CUDAStream* op_stream; ucc_ee_h* op_ee; if (opType == OpType::SEND) { op_stream = stream_p2p[0].get(); op_ee = &cuda_ee_p2p[0]; } else if (opType == OpType::RECV) { op_stream = stream_p2p[1].get(); op_ee = &cuda_ee_p2p[1]; } else { op_stream = stream.get(); op_ee = &cuda_ee; } cuda_ev->record(at::cuda::getCurrentCUDAStream(dev.index())); cuda_ev->block(*op_stream); at::cuda::CUDAStreamGuard guard(*op_stream); preproc(); comm->enqueue_cuda_collective(std::move(data), work, coll, team, *op_ee); postproc(); cuda_ev->record(*op_stream); work->fence = std::move(cuda_ev); work->ep = &ep; if (torch_ucc_config.use_future) { c10::cuda::CUDAMultiStreamGuard streamGuard(*op_stream); std::vector devList{dev}; work->future_ = c10::make_intrusive( c10::ListType::create(c10::TensorType::get()), devList); // Add a callback that runs profiling end callbacks if (work->recordFunctionEndCallback_) { work->future_->addCallback([work](at::ivalue::Future& /* unused */) { work->recordFunctionEndCallback_(); }); } work->future_->markCompleted(c10::IValue(outputTensors)); } return work; } #endif // #ifdef USE_CUDA default: { TORCH_UCC_LOG_ERROR( TORCH_UCC_COLL_POST, c10::str("unsupported device type ", dev.str())); throw std::invalid_argument(ucc_status_string(UCC_ERR_NOT_SUPPORTED)); } } } c10::intrusive_ptr ProcessGroupUCC::allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& /* unused */) { auto& tensor = inputTensors[0]; check_device(tensor.device(), outputTensors[0][0].device()); initComm(tensor.device()); if (tensor.device().is_cpu() || torch_ucc_config.use_allgatherv) { AllgathervWorkData* data = new AllgathervWorkData(size_); for (int i = 0; i < size_; i++) { data->recv_lengths[i] = tensor.element_size() * tensor.numel(); data->recv_offsets[i] = (uint64_t)outputTensors[0][i].data_ptr(); } ucc_coll_args_t coll; coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll.flags = UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT; coll.coll_type = UCC_COLL_TYPE_ALLGATHERV; coll.src.info.buffer = tensor.data_ptr(); coll.src.info.count = tensor.element_size() * tensor.numel(); coll.src.info.datatype = UCC_DT_UINT8; coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); coll.dst.info_v.buffer = nullptr; coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data(); coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data(); coll.dst.info_v.datatype = UCC_DT_UINT8; coll.dst.info_v.mem_type = to_ucc_memType(outputTensors[0][0].device().type()); SAVE_TENSORS(inputTensors, data->src); SAVE_TENSORS(outputTensors[0], data->dst); return collective_post( OpType::ALLGATHER, []() {}, []() {}, coll, std::unique_ptr(data), tensor.device(), inputTensors, outputTensors[0], "ucc:all_gather"); } else { WorkData* data = new WorkData(); std::vector flat_output(outputTensors.size()); for (size_t i = 0; i < outputTensors.size(); i++) { TORCH_CHECK( outputTensors[i].size() == outputTensors.size() * size_, "Tensor output list is not valid for the number of participants"); flat_output[i] = c10d::newLikeFlat(outputTensors, i); } SAVE_TENSORS(flat_output, data->flat); ucc_coll_args_t coll; coll.mask = 0; coll.flags = 0; 
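// Fallback path: post a regular UCC allgather into the flattened output
// buffers; copy_from_flat (registered below as the post-processing step)
// scatters the flat result back into the caller's per-rank output tensors.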
coll.coll_type = UCC_COLL_TYPE_ALLGATHER; coll.src.info.buffer = tensor.data_ptr(); coll.src.info.count = tensor.numel(); coll.src.info.datatype = to_ucc_dType(tensor); coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); coll.dst.info.buffer = flat_output[0].data_ptr(); coll.dst.info.count = flat_output[0].numel(); coll.dst.info.datatype = to_ucc_dType(flat_output[0]); coll.dst.info.mem_type = to_ucc_memType(outputTensors[0][0].device().type()); auto copy_from_flat = [&] { bool asyncCopy = false; #ifdef USE_CUDA bool isCuda = outputTensors[0][0].device().is_cuda(); ; #endif for (size_t i = 0; i < outputTensors.size(); i++) { auto inumel = inputTensors[i].numel(); for (size_t j = 0; j < outputTensors[i].size(); j++) { TORCH_CHECK( (outputTensors[i][j].numel() == inumel), "Tensor operand counts must be same"); #ifdef USE_CUDA if (isCuda) { c10::cuda::CUDACachingAllocator::recordStream( outputTensors[i][j].storage().data_ptr(), (*stream)); asyncCopy = true; } #endif outputTensors[i][j].copy_(flat_output[i][j], asyncCopy); } } }; return collective_post( OpType::ALLGATHER, []() {}, copy_from_flat, coll, std::unique_ptr(data), tensor.device(), inputTensors, outputTensors[0], "ucc:all_gather"); } } c10::intrusive_ptr ProcessGroupUCC::_allgather_base( at::Tensor& outputTensor, at::Tensor& inputTensor, const AllgatherOptions& opts) { check_tensor({outputTensor}); check_tensor({inputTensor}); initComm(outputTensor.device()); WorkData* data = new WorkData(); ucc_coll_args_t coll; coll.mask = 0; coll.flags = 0; coll.coll_type = UCC_COLL_TYPE_ALLGATHER; coll.src.info.buffer = inputTensor.data_ptr(); coll.src.info.count = inputTensor.numel(); coll.src.info.datatype = ucc_dtype_map.at(inputTensor.scalar_type()); coll.src.info.mem_type = to_ucc_memType(inputTensor.device().type()); coll.dst.info.buffer = outputTensor.data_ptr(); coll.dst.info.count = outputTensor.numel(); coll.dst.info.datatype = ucc_dtype_map.at(outputTensor.scalar_type()); coll.dst.info.mem_type = to_ucc_memType(outputTensor.device().type()); std::vector inputTensors = {inputTensor}; std::vector outputTensors = {outputTensor}; SAVE_TENSORS(inputTensors, data->src); SAVE_TENSORS(outputTensors, data->dst); return collective_post( OpType::_ALLGATHER_BASE, []() {}, []() {}, coll, std::unique_ptr(data), outputTensor.device(), inputTensors, outputTensors, "ucc:allgather_base"); } c10::intrusive_ptr ProcessGroupUCC::allreduce( std::vector& tensors, const AllreduceOptions& opts) { check_tensor(tensors); auto& tensor = tensors[0]; initComm(tensor.device()); WorkData* data = new WorkData(); ucc_coll_args_t coll; coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; coll.coll_type = UCC_COLL_TYPE_ALLREDUCE; coll.op = to_ucc_reduceOp(opts.reduceOp, tensor.scalar_type()); coll.src.info.buffer = nullptr; coll.src.info.count = tensor.numel(); coll.src.info.datatype = to_ucc_dType(tensor); coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); coll.dst.info.buffer = tensor.data_ptr(); coll.dst.info.count = tensor.numel(); coll.dst.info.datatype = to_ucc_dType(tensor); coll.dst.info.mem_type = to_ucc_memType(tensor.device().type()); SAVE_TENSORS(tensors, data->dst); return collective_post( OpType::ALLREDUCE, []() {}, []() {}, coll, std::unique_ptr(data), tensor.device(), tensors, tensors, "ucc:all_reduce"); } c10::intrusive_ptr ProcessGroupUCC::allreduce_coalesced( std::vector& /* unused */, const AllreduceCoalescedOptions& /* unused */) { throw std::invalid_argument( "ProcessGroupUCC does not support 
allreduce_coalesced"); } c10::intrusive_ptr ProcessGroupUCC::alltoall( std::vector& outputTensors, std::vector& inputTensors, const AllToAllOptions& /* unused */) { auto device = outputTensors[0].device(); for (const auto r : c10::irange(outputTensors.size())) { TORCH_CHECK( device == outputTensors[r].device() && device == inputTensors[r].device(), "Tensors must be on the same device") } initComm(device); ucc_coll_args_t coll; AlltoallWorkData* data; data = new AlltoallWorkData(size_); /* to avoid flatten the tensors, we use alltoallv to achieve Alltoall as follow. 1. store addresses of each tensor directly in displacements, keep buffer to nullptr, i.e., 0 2. convert datatype to UINT8, which is always 1 bytes, to avoid wrong size calculation in UCC layer 3. post Alltoallv */ for (const auto i : c10::irange(size_)) { data->send_lengths[i] = (uint64_t)(inputTensors[i].element_size() * inputTensors[i].numel()); data->send_offsets[i] = (uint64_t)inputTensors[i].data_ptr(); data->recv_lengths[i] = (uint64_t)(outputTensors[i].element_size() * outputTensors[i].numel()); data->recv_offsets[i] = (uint64_t)outputTensors[i].data_ptr(); } coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll.flags = UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT; coll.coll_type = UCC_COLL_TYPE_ALLTOALLV; coll.src.info_v.buffer = 0; coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data(); coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data(); coll.src.info_v.datatype = UCC_DT_UINT8; coll.src.info_v.mem_type = to_ucc_memType(inputTensors[0].device().type()); coll.dst.info_v.buffer = 0; coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data(); coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data(); coll.dst.info_v.datatype = UCC_DT_UINT8; coll.dst.info_v.mem_type = to_ucc_memType(outputTensors[0].device().type()); SAVE_TENSORS(inputTensors, data->src); SAVE_TENSORS(outputTensors, data->dst); return collective_post( OpType::ALLTOALL, []() {}, []() {}, coll, std::unique_ptr(data), device, inputTensors, outputTensors, "ucc:alltoall"); } c10::intrusive_ptr ProcessGroupUCC::alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, std::vector& inputSplitSizes, const AllToAllOptions& /* unused */) { check_device(inputTensor.device(), outputTensor.device()); initComm(inputTensor.device()); ucc_coll_args_t coll; AlltoallWorkData* data; if ((outputSplitSizes.size() == 0) && (inputSplitSizes.size() == 0)) { data = new AlltoallWorkData(0); TORCH_CHECK( (outputTensor.size(0) % size_ == 0) && (inputTensor.size(0) % size_ == 0), "Tensor's dim 0 does not divide equally across group size"); coll.mask = 0; coll.flags = 0; coll.coll_type = UCC_COLL_TYPE_ALLTOALL; coll.src.info.buffer = inputTensor.data_ptr(); coll.src.info.count = inputTensor.element_size() * inputTensor.numel(); coll.src.info.datatype = UCC_DT_UINT8; coll.src.info.mem_type = to_ucc_memType(inputTensor.device().type()); coll.dst.info.buffer = outputTensor.data_ptr(); coll.dst.info.count = outputTensor.element_size() * outputTensor.numel(); coll.dst.info.datatype = UCC_DT_UINT8; coll.dst.info.mem_type = to_ucc_memType(outputTensor.device().type()); coll.flags = 0; } else { data = new AlltoallWorkData(size_); c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_); computeLengthsAndOffsets( outputSplitSizes, outputTensor, &data->recv_lengths, &data->recv_offsets); computeLengthsAndOffsets( inputSplitSizes, 
inputTensor, &data->send_lengths, &data->send_offsets); coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll.coll_type = UCC_COLL_TYPE_ALLTOALLV; coll.src.info_v.buffer = inputTensor.data_ptr(); coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data(); coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data(); coll.src.info_v.datatype = to_ucc_dType(inputTensor); coll.src.info_v.mem_type = to_ucc_memType(inputTensor.device().type()); coll.dst.info_v.buffer = outputTensor.data_ptr(); coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data(); coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data(); coll.dst.info_v.datatype = to_ucc_dType(outputTensor); coll.dst.info_v.mem_type = to_ucc_memType(outputTensor.device().type()); coll.flags = UCC_COLL_ARGS_FLAG_CONTIG_SRC_BUFFER | UCC_COLL_ARGS_FLAG_CONTIG_DST_BUFFER | UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT; if (torch_ucc_config.enable_comms_logger) { logger->trace_generator->recordOptionalInfo( outputSplitSizes, inputSplitSizes); } } std::vector inputTensors = {inputTensor}; std::vector outputTensors = {outputTensor}; SAVE_TENSORS(inputTensors, data->src); SAVE_TENSORS(outputTensors, data->dst); return collective_post( OpType::ALLTOALL_BASE, []() {}, []() {}, coll, std::unique_ptr(data), inputTensor.device(), inputTensors, outputTensors, "ucc:alltoall"); } c10::intrusive_ptr ProcessGroupUCC::barrier(const BarrierOptions& opts) { c10::Device device = c10::Device(c10::DeviceType::CPU); #ifdef USE_CUDA auto numGPUs = c10::cuda::device_count(); if (!opts.device_ids.empty()) { device = c10::Device(c10::DeviceType::CUDA, opts.device_ids.front()); } else if (comm && comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) { device = c10::Device(c10::DeviceType::CUDA, comm->cuda_device_index); } else if (numGPUs > 0) { int8_t deviceIdx = static_cast(c10::cuda::current_device()); // if current device is 0, likely the device is not set, use the best guess if (0 == (int)deviceIdx) { deviceIdx = static_cast(this->getRank() % numGPUs); } TORCH_UCC_LOG_INFO( TORCH_UCC_COLL_POST, c10::str( "post barrier before specifying any GPU while there are ", numGPUs, " GPUs available. ", "Not clear if GPU barrier is required, using GPU ", (int)deviceIdx, " to perform barrier. 
", "Specify device_ids option in barrier() to force ", "use of a particular device")); device = c10::Device(c10::DeviceType::CUDA, deviceIdx); } #endif initComm(device); ucc_coll_args_t coll; coll.mask = 0; coll.flags = 0; coll.coll_type = UCC_COLL_TYPE_BARRIER; auto dummy_tensor = std::vector(); return collective_post( OpType::BARRIER, []() {}, []() {}, coll, nullptr, device, dummy_tensor, dummy_tensor, "ucc:barrier"); } c10::intrusive_ptr ProcessGroupUCC::broadcast( std::vector& tensors, const BroadcastOptions& opts) { check_tensor(tensors); auto& tensor = tensors[0]; initComm(tensor.device()); WorkData* data = new WorkData(); ucc_coll_args_t coll; coll.mask = 0; coll.flags = 0; coll.coll_type = UCC_COLL_TYPE_BCAST; coll.src.info.buffer = tensor.data_ptr(); coll.src.info.count = tensor.numel(); coll.src.info.datatype = to_ucc_dType(tensor); coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); coll.root = opts.rootRank; SAVE_TENSORS(tensors, data->dst); if (torch_ucc_config.enable_comms_logger) { logger->trace_generator->recordOptionalInfo(opts.rootRank); } return collective_post( OpType::BROADCAST, []() {}, []() {}, coll, std::unique_ptr(data), tensor.device(), tensors, tensors, "ucc:broadcast"); } c10::intrusive_ptr ProcessGroupUCC::gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts) { std::vector outputs; auto& input = inputTensors[0]; initComm(input.device()); AllgathervWorkData* data = new AllgathervWorkData(size_); ucc_coll_args_t coll; coll.root = opts.rootRank; coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll.flags = UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT; coll.coll_type = UCC_COLL_TYPE_GATHERV; /* for non-root ranks, only src is valid */ coll.src.info.buffer = input.data_ptr(); coll.src.info.count = (uint64_t)(input.element_size() * input.numel()); coll.src.info.datatype = UCC_DT_UINT8; coll.src.info.mem_type = to_ucc_memType(input.device().type()); if (getRank() == opts.rootRank) { if (outputTensors.size() != 1) { TORCH_UCC_LOG_ERROR( TORCH_UCC_COLL_POST, c10::str( "gather requires a single-element output list containing a list with ", getSize(), " tensors.")); } else if (outputTensors[0].size() != static_cast(getSize())) { TORCH_UCC_LOG_ERROR( TORCH_UCC_COLL_POST, c10::str( "Incorrect output list size ", outputTensors[0].size(), ". 
Output list size should be ", getSize(), ", same as size of the process group.")); } outputs = outputTensors[0]; for (int i = 0; i < size_; i++) { data->recv_lengths[i] = (uint64_t)(outputs[i].element_size() * outputs[i].numel()); data->recv_offsets[i] = (uint64_t)outputs[i].data_ptr(); } /* use gatherv and store non-contiguous addresses in displacements to avoid * flatten outputTensors */ coll.dst.info_v.buffer = nullptr; coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data(); coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data(); coll.dst.info_v.datatype = UCC_DT_UINT8; coll.dst.info_v.mem_type = to_ucc_memType(outputs[0].device().type()); SAVE_TENSORS(outputs, data->dst); } else { // for non-root ranks, outputTensors should be an empty list if (outputTensors.size() != 0) { TORCH_UCC_LOG_ERROR( TORCH_UCC_COLL_POST, "requires empty output on non-root"); } outputs = {}; // append a empty tensor to the list to be used by future mark outputs.emplace_back(); } SAVE_TENSORS(inputTensors, data->src); return collective_post( OpType::GATHER, []() {}, []() {}, coll, std::unique_ptr(data), input.device(), inputTensors, outputs, "ucc:gather"); } c10::intrusive_ptr ProcessGroupUCC::reduce( std::vector& tensors, const ReduceOptions& opts) { check_tensor(tensors); auto& tensor = tensors[0]; initComm(tensor.device()); WorkData* data = new WorkData(); ucc_coll_args_t coll; coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; coll.coll_type = UCC_COLL_TYPE_REDUCE; coll.op = ucc_op_map.at(opts.reduceOp); coll.root = opts.rootRank; coll.src.info.buffer = tensor.data_ptr(); coll.src.info.count = tensor.numel(); coll.src.info.datatype = ucc_dtype_map.at(tensor.scalar_type()); coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); coll.dst.info.buffer = tensor.data_ptr(); coll.dst.info.count = tensor.numel(); coll.dst.info.datatype = ucc_dtype_map.at(tensor.scalar_type()); coll.dst.info.mem_type = to_ucc_memType(tensor.device().type()); SAVE_TENSORS(tensors, data->dst); return collective_post( OpType::REDUCE, []() {}, []() {}, coll, std::unique_ptr(data), tensor.device(), tensors, tensors, "ucc:reduce"); } c10::intrusive_ptr ProcessGroupUCC::reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts) { TORCH_CHECK( (outputTensors.size() == inputTensors.size()), "Tensor input/output list for reduce_scatter must have same size"); check_tensor(outputTensors); check_device(inputTensors[0][0].device(), outputTensors[0].device()); initComm(inputTensors[0][0].device()); auto data = std::make_unique(); std::vector flat_input(inputTensors.size()); for (size_t i = 0; i < inputTensors.size(); i++) { TORCH_CHECK( inputTensors[i].size() == inputTensors.size() * size_, "Tensor input list is not valid for the number of participants"); flat_input[i] = c10d::newLikeFlat(inputTensors, i); } SAVE_TENSORS(flat_input, data->flat); check_tensor(flat_input); ucc_coll_args_t coll; coll.mask = 0; coll.flags = 0; coll.coll_type = UCC_COLL_TYPE_REDUCE_SCATTER; coll.op = to_ucc_reduceOp(opts.reduceOp, flat_input[0].scalar_type()); coll.src.info.buffer = flat_input[0].data_ptr(); coll.src.info.count = flat_input[0].numel(); coll.src.info.datatype = to_ucc_dType(flat_input[0]); coll.src.info.mem_type = to_ucc_memType(flat_input[0].device().type()); coll.dst.info.buffer = outputTensors[0].data_ptr(); coll.dst.info.count = outputTensors[0].numel(); coll.dst.info.datatype = to_ucc_dType(outputTensors[0]); coll.dst.info.mem_type = 
to_ucc_memType(outputTensors[0].device().type()); SAVE_TENSORS(inputTensors[0], data->src); SAVE_TENSORS(outputTensors, data->dst); auto copy_to_flat = [&] { bool asyncCopy = false; auto isize = inputTensors.size(); #ifdef USE_CUDA bool isCuda = inputTensors[0][0].device().is_cuda(); #endif for (size_t i = 0; i < isize; i++) { auto onumel = outputTensors[i].numel(); for (size_t j = 0; j < inputTensors[i].size(); j++) { TORCH_CHECK( (inputTensors[i][j].numel() == onumel), "Tensor operand counts must be same"); #ifdef USE_CUDA if (isCuda) { c10::cuda::CUDACachingAllocator::recordStream( inputTensors[i][j].storage().data_ptr(), (*stream)); asyncCopy = true; } #endif flat_input[i][j].copy_(inputTensors[i][j], asyncCopy); } } }; return collective_post( OpType::REDUCE_SCATTER, copy_to_flat, []() {}, coll, std::move(data), inputTensors[0][0].device(), inputTensors[0], outputTensors, "ucc:reduce_scatter"); } c10::intrusive_ptr ProcessGroupUCC::scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts) { auto& tensor = outputTensors[0]; initComm(tensor.device()); ScattervWorkData* data = new ScattervWorkData(size_); ucc_coll_args_t coll; coll.root = opts.rootRank; coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll.flags = UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT; coll.coll_type = UCC_COLL_TYPE_SCATTERV; if (getRank() == opts.rootRank) { /* src is only valid at non-root rank */ if (inputTensors.size() != 1) { TORCH_UCC_LOG_ERROR( TORCH_UCC_COLL_POST, c10::str( "gather requires a single-element output list containing a list with ", getSize(), " tensors.")); } else if (inputTensors[0].size() != static_cast(getSize())) { TORCH_UCC_LOG_ERROR( TORCH_UCC_COLL_POST, c10::str( "Incorrect output list size ", inputTensors[0].size(), ". Output list size should be ", getSize(), ", same as size of the process group.")); } for (int i = 0; i < size_; i++) { data->send_lengths[i] = (uint64_t)tensor.element_size() * tensor.numel(); data->send_offsets[i] = (uint64_t)inputTensors[0][i].data_ptr(); } /* use scatter and store non-contiguous addresses in displacements to avoid * flatten inputTensors */ coll.src.info_v.buffer = nullptr; coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data(); coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data(); coll.src.info_v.datatype = UCC_DT_UINT8; coll.src.info_v.mem_type = to_ucc_memType(inputTensors[0][0].device().type()); SAVE_TENSORS(inputTensors[0], data->src); } else { // for non-root ranks, inputTensors should be an empty list if (inputTensors.size() != 0) { TORCH_UCC_LOG_ERROR( TORCH_UCC_COLL_POST, "requires empty output on non-root"); } } coll.dst.info.buffer = tensor.data_ptr(); coll.dst.info.count = (uint64_t)tensor.element_size() * tensor.numel(); coll.dst.info.datatype = UCC_DT_UINT8; coll.dst.info.mem_type = to_ucc_memType(tensor.device().type()); SAVE_TENSORS(outputTensors, data->dst); return collective_post( OpType::SCATTER, []() {}, []() {}, coll, std::unique_ptr(data), tensor.device(), (getRank() == opts.rootRank) ? 
inputTensors[0] : outputTensors, outputTensors, "ucc:scatter"); } c10::intrusive_ptr ProcessGroupUCC::send( std::vector& tensors, int dstRank, int tag) { check_tensor(tensors); auto& tensor = tensors[0]; initComm(tensor.device()); WorkData* data = new WorkData(); ucc_coll_args_t coll; coll.tag = tag; coll.mask = UCC_COLL_ARGS_FIELD_ACTIVE_SET | UCC_COLL_ARGS_FIELD_TAG; coll.flags = 0; coll.coll_type = UCC_COLL_TYPE_BCAST; coll.src.info.buffer = tensor.data_ptr(); coll.src.info.count = tensor.numel(); coll.src.info.datatype = to_ucc_dType(tensor); coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); coll.root = getRank(); coll.active_set.size = 2; coll.active_set.start = getRank(); coll.active_set.stride = dstRank - getRank(); SAVE_TENSORS(tensors, data->dst); return collective_post( OpType::SEND, []() {}, []() {}, coll, std::unique_ptr(data), tensor.device(), tensors, tensors, "ucc:send"); } c10::intrusive_ptr ProcessGroupUCC::recv( std::vector& tensors, int srcRank, int tag) { check_tensor(tensors); auto& tensor = tensors[0]; initComm(tensor.device()); WorkData* data = new WorkData(); ucc_coll_args_t coll; coll.tag = tag; coll.mask = UCC_COLL_ARGS_FIELD_ACTIVE_SET | UCC_COLL_ARGS_FIELD_TAG; coll.flags = 0; coll.coll_type = UCC_COLL_TYPE_BCAST; coll.src.info.buffer = tensor.data_ptr(); coll.src.info.count = tensor.numel(); coll.src.info.datatype = to_ucc_dType(tensor); coll.src.info.mem_type = to_ucc_memType(tensor.device().type()); coll.root = srcRank; coll.active_set.size = 2; coll.active_set.start = srcRank; coll.active_set.stride = getRank() - srcRank; SAVE_TENSORS(tensors, data->dst); return collective_post( OpType::RECV, []() {}, []() {}, coll, std::unique_ptr(data), tensor.device(), tensors, tensors, "ucc:recv"); } void ProcessGroupUCC::setSequenceNumberForGroup() {} uint64_t ProcessGroupUCC::getSequenceNumberForGroup() { return seq_; } c10::intrusive_ptr ProcessGroupUCC::createProcessGroupUCC( const c10::intrusive_ptr<::c10d::Store>& store, int rank, int size, const std::chrono::duration& timeout) { return c10::make_intrusive(store, rank, size, timeout); } void ProcessGroupUCC::initComm(c10::Device dev) { if (!comm) { #ifdef USE_CUDA if (dev.is_cuda()) { c10::cuda::set_device(dev.index()); } #endif comm = Comm::get_comm(comm_id, dev, oob, logger); TORCH_UCC_LOG_INFO(TORCH_UCC_INIT, "Successfully initialized UCX library"); comm->ucc_create_team(team, oob); TORCH_UCC_LOG_INFO(TORCH_UCC_INIT, "Successfully initialized UCC library"); logger->setPhase(TORCH_UCC_READY); } else { if (dev.is_cuda()) { if ((comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) && (comm->cuda_device_index != dev.index())) { TORCH_UCC_LOG_ERROR( TORCH_UCC_INIT, "ucc communicator was initialized with different cuda device," "multi device is not supported"); throw std::invalid_argument(ucc_status_string(UCC_ERR_NOT_SUPPORTED)); } comm->cuda_device_index = dev.index(); } } #ifdef USE_CUDA // Create UCC execution engine. 
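// A dedicated CUDA stream backs the collective execution engine, and two extra
// stream/EE pairs are created for point-to-point traffic (collective_post picks
// cuda_ee_p2p[0] for OpType::SEND and cuda_ee_p2p[1] for OpType::RECV).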
  if (!cuda_ee && dev.is_cuda()) {
    stream = std::make_unique<at::cuda::CUDAStream>(
        at::cuda::getStreamFromPool(true, dev.index()));
    ucc_ee_params_t params;
    params.ee_type = UCC_EE_CUDA_STREAM;
    params.ee_context = (void*)stream->stream();
    params.ee_context_size = sizeof(cudaStream_t);
    TORCH_UCC_CHECK(
        ucc_ee_create(team, &params, &cuda_ee),
        "failed to create UCC execution engine");
    for (int i = 0; i < 2; i++) {
      stream_p2p[i] = std::make_unique<at::cuda::CUDAStream>(
          at::cuda::getStreamFromPool(true, dev.index()));
      ucc_ee_params_t params;
      params.ee_type = UCC_EE_CUDA_STREAM;
      params.ee_context = (void*)stream_p2p[i]->stream();
      params.ee_context_size = sizeof(cudaStream_t);
      TORCH_UCC_CHECK(
          ucc_ee_create(team, &params, &cuda_ee_p2p[i]),
          "failed to create UCC P2P execution engine");
    }
  }
#endif
}

} // namespace c10d

#endif // USE_C10D_UCC