#include <torch/csrc/cuda/python_comm.h>

#include <ATen/core/functional.h>
#include <torch/csrc/cuda/Stream.h>
#include <torch/csrc/cuda/THCP.h>
#include <torch/csrc/cuda/comm.h>
#include <torch/csrc/utils/pybind.h>

#include <ATen/ATen.h>

#include <cstddef>
#include <optional>
#include <vector>

namespace torch::cuda::python {

// Registers the private comm bindings (_broadcast, _scatter, _gather, ...)
// on the given Python module. These back the helpers in torch.cuda.comm.
void initCommMethods(PyObject* module) {
  auto m = py::cast<py::module>(module);
  m.def(
       "_broadcast_coalesced",
       [](std::vector<at::Tensor>& tensors,
          const std::vector<int64_t>& devices,
          size_t buffer_size) {
         return broadcast_coalesced(tensors, devices, buffer_size);
       },
       py::arg("tensors"),
       py::arg("devices"),
       py::arg("buffer_size"),
       // The lambda touches no Python objects, so the GIL can be released
       // for the entire call.
       py::call_guard<py::gil_scoped_release>())
      .def(
          "_broadcast",
          [](at::Tensor& tensor, std::vector<int64_t> devices) {
            return broadcast(tensor, devices);
          },
          py::call_guard<py::gil_scoped_release>(),
          py::arg("tensor"),
          py::arg("devices"))
      .def(
          "_broadcast_out",
          [](at::Tensor& tensor, std::vector<at::Tensor>& out_tensors) {
            return broadcast_out(tensor, out_tensors);
          },
          py::call_guard<py::gil_scoped_release>(),
          py::arg("tensor"),
          py::arg("out"))
      .def(
          "_scatter",
          [](at::Tensor& tensor,
             std::vector<int64_t>& devices,
             std::optional<std::vector<int64_t>> chunk_sizes,
             int64_t dim,
             std::optional<py::object> py_streams) {
            std::optional<std::vector<std::optional<at::cuda::CUDAStream>>>
                streams;
            if (py_streams) {
              py::handle handle = *py_streams;
              streams = THPUtils_PySequence_to_CUDAStreamList(handle.ptr());
            }
            // Note: We're holding the GIL up to here; it is required to
            // convert the Python stream sequence above. Release it before
            // the actual scatter so other Python threads can make progress
            // during the device copies.
            pybind11::gil_scoped_release no_gil;
            return scatter(tensor, devices, chunk_sizes, dim, streams);
          },
          py::arg("tensor"),
          py::arg("devices"),
          py::arg("chunk_sizes"),
          py::arg("dim"),
          py::arg("streams"))
      .def(
          "_scatter_out",
          [](at::Tensor& tensor,
             std::vector<at::Tensor>& out_tensors,
             int64_t dim,
             std::optional<py::object> py_streams) {
            std::optional<std::vector<std::optional<at::cuda::CUDAStream>>>
                streams;
            if (py_streams) {
              py::handle handle = *py_streams;
              streams = THPUtils_PySequence_to_CUDAStreamList(handle.ptr());
            }
            // Note: We're holding the GIL up to here (needed for the stream
            // conversion above); release it for the device copies.
            pybind11::gil_scoped_release no_gil;
            return scatter_out(tensor, out_tensors, dim, streams);
          },
          py::arg("tensor"),
          py::arg("out"),
          py::arg("dim"),
          py::arg("streams"))
      .def(
          "_gather",
          [](std::vector<at::Tensor>& tensors,
             int64_t dim,
             std::optional<int32_t> destination_index) {
            return gather(tensors, dim, destination_index);
          },
          py::arg("tensors"),
          py::arg("dim"),
          py::arg("destination_index"),
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_gather_out",
          [](std::vector<at::Tensor>& tensors,
             at::Tensor& out_tensor,
             int64_t dim) { return gather_out(tensors, out_tensor, dim); },
          py::arg("tensors"),
          py::arg("out"),
          py::arg("dim"),
          py::call_guard<py::gil_scoped_release>());
}

} // namespace torch::cuda::python
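// Usage sketch (illustrative, not part of this file): initCommMethods is
// invoked during torch._C extension initialization, roughly like
//
//   PyObject* c_module = PyImport_ImportModule("torch._C");
//   torch::cuda::python::initCommMethods(c_module);
//
// after which the bindings above back the public Python-side helpers in
// torch.cuda.comm, e.g.:
//
//   import torch
//   t = torch.arange(4, device="cuda:0")
//   copies = torch.cuda.comm.broadcast(t, [0, 1])       # -> _broadcast
//   chunks = torch.cuda.comm.scatter(t, [0, 1], dim=0)  # -> _scatter
//   merged = torch.cuda.comm.gather(chunks, dim=0)      # -> _gather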