#include #include #include #include #include #include // Cargo culted partially from csrc/distributed/c10d/init.cpp // and partially from csrc/cuda/Stream.cpp. // THCPStream_init is also declared at global scope. // Because THCPGraph_init is forward declared in the only consumer // (csrc/Module.cpp) I don't think we need a Graph.h. template using shared_ptr_class_ = py::class_>; void THCPGraph_init(PyObject* module) { // Pybind11 patch notes say "py::module_" is more up-to-date syntax, // but CI linter and some builds prefer "module". auto torch_C_m = py::handle(module).cast(); torch_C_m.def("_graph_pool_handle", &::at::cuda::graph_pool_handle); shared_ptr_class_<::at::cuda::CUDAGraph>(torch_C_m, "_CUDAGraph") .def(py::init<>()) .def( "capture_begin", [](::at::cuda::CUDAGraph& self, std::optional pool_opt, std::string capture_error_mode) { cudaStreamCaptureMode capture_mode; c10::cuda::MempoolId_t pool = pool_opt.has_value() ? pool_opt.value() : c10::cuda::MempoolId_t{0, 0}; if (capture_error_mode == "global") { capture_mode = cudaStreamCaptureModeGlobal; } else if (capture_error_mode == "thread_local") { capture_mode = cudaStreamCaptureModeThreadLocal; } else if (capture_error_mode == "relaxed") { capture_mode = cudaStreamCaptureModeRelaxed; } else { TORCH_CHECK( false, "Unknown capture error mode. Expected `global`, `thread_local`, or `relaxed`, got ", capture_error_mode); } return self.capture_begin(pool, capture_mode); }, py::arg("pool"), py::arg("capture_error_mode"), py::call_guard()) .def( "capture_end", torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::capture_end)) .def( "register_generator_state", [](::at::cuda::CUDAGraph& self, py::handle raw_generator) { auto generator = THPGenerator_Unwrap(raw_generator.ptr()); // We've unwrapped Python object to C++ object, // so we could release GIL before calling into C++ py::gil_scoped_release release; return self.register_generator_state(generator); }, py::arg("generator")) .def( "replay", torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::replay)) .def( "reset", torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::reset)) .def( "pool", torch::wrap_pybind_function_no_gil(&at::cuda::CUDAGraph::pool)) .def( "debug_dump", torch::wrap_pybind_function_no_gil( &::at::cuda::CUDAGraph::debug_dump)) .def( "enable_debug_mode", torch::wrap_pybind_function_no_gil( &::at::cuda::CUDAGraph::enable_debug_mode)) .def( "debug_dump", torch::wrap_pybind_function_no_gil( &::at::cuda::CUDAGraph::debug_dump), py::arg("debug_path")); }