#pragma once

#include <ATen/ATen.h>

#include <torch/csrc/distributed/c10d/Store.hpp>

namespace c10d {
namespace symmetric_memory {

// SymmetricMemory represents symmetric allocations across a group of devices.
// The allocations represented by a SymmetricMemory object are accessible by
// all devices in the group. The class can be used for op-level custom
// communication patterns (via the get_buffer APIs and the synchronization
// primitives), as well as custom communication kernels (via the buffer and
// signal_pad device pointers).
//
// To acquire a SymmetricMemory object, each rank first allocates
// identical-sized memory via SymmetricMemoryAllocator::alloc(), then invokes
// SymmetricMemoryAllocator::rendezvous() on the memory to establish the
// association across peer buffers. The rendezvous is a one-time process, and
// the mapping between a local memory region and the associated SymmetricMemory
// object is unique.
//
// NOTE [symmetric memory signal pad]
// Signal pads are P2P-accessible memory regions designated for
// synchronization. SymmetricMemory offers built-in synchronization primitives
// such as barriers, put_signal, and wait_signal, which are all based on signal
// pads. Users may utilize signal pads for their own synchronization logic,
// provided that the signal pads remain zero-filled following successful
// synchronization.
//
// NOTE [symmetric memory synchronization channel]
// Synchronization channels allow users to use a single SymmetricMemory object
// to perform isolated synchronizations on different streams. For example,
// consider the case in which two barriers are issued on two streams for
// different purposes. Without the concept of channels, we cannot guarantee the
// correctness of the barriers, since signals issued from the barrier on stream
// A can be received by the barrier on stream B. By specifying different
// channels for these two barriers, they can operate correctly in parallel.
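// Example (sketch): a typical op-level flow on each rank, shown here only for
// illustration and not prescribed by this header. The group name "0", the
// tensor shape, and the variables `local_device_idx` and `peer_rank` are
// arbitrary placeholders, and a store is assumed to have been registered for
// the group via set_group_info() (declared below).
//
//   // Allocate a symmetric buffer and establish peer access. The allocation
//   // is local; rendezvous() is collective and must run on all participants.
//   at::Tensor buf = empty_strided_p2p(
//       {1024}, {1}, at::kFloat, at::Device(at::kCUDA, local_device_idx),
//       /*group_name=*/"0", /*alloc_id=*/std::nullopt);
//   auto symm_mem = rendezvous(buf);
//
//   // Obtain a tensor view over a peer rank's buffer, then synchronize all
//   // ranks on channel 0 before touching remote memory.
//   at::Tensor peer_buf = symm_mem->get_buffer(
//       peer_rank, {1024}, at::kFloat, /*storage_offset=*/0);
//   symm_mem->barrier(/*channel=*/0);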
class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target {
 public:
  virtual ~SymmetricMemory() {}

  virtual std::vector<void*> get_buffer_ptrs() = 0;
  virtual std::vector<void*> get_signal_pad_ptrs() = 0;

  // get_buffer_ptrs_dev() and get_signal_pad_ptrs_dev() each return a pointer
  // to a device array of size world_size, containing buffer pointers and
  // signal pad pointers, respectively.
  virtual void** get_buffer_ptrs_dev() = 0;
  virtual void** get_signal_pad_ptrs_dev() = 0;
  virtual size_t get_buffer_size() = 0;
  virtual size_t get_signal_pad_size() = 0;

  virtual bool has_multicast_support() = 0;
  virtual void* get_multicast_ptr() = 0;

  virtual at::Tensor get_buffer(
      int rank,
      c10::IntArrayRef sizes,
      c10::ScalarType dtype,
      int64_t storage_offset) = 0;

  virtual void barrier(int channel) = 0;
  virtual void put_signal(int dst_rank, int channel) = 0;
  virtual void wait_signal(int src_rank, int channel) = 0;

  virtual int get_rank() = 0;
  virtual int get_world_size() = 0;
};

class SymmetricMemoryAllocator : public c10::intrusive_ptr_target {
 public:
  virtual ~SymmetricMemoryAllocator() {}

  virtual void* alloc(
      size_t size,
      int device_idx,
      const std::string& group_name) = 0;

  virtual void free(void* ptr) = 0;
  virtual size_t get_alloc_size(void* ptr) = 0;
  virtual c10::intrusive_ptr<SymmetricMemory> rendezvous(void* ptr) = 0;
  virtual bool is_rendezvous_completed(void* ptr) = 0;
  virtual bool has_multicast_support(int device_idx) = 0;
};

C10_EXPORT bool is_finalizing();

C10_EXPORT void register_allocator(
    c10::DeviceType device_type,
    c10::intrusive_ptr<SymmetricMemoryAllocator> allocator);

C10_EXPORT c10::intrusive_ptr<SymmetricMemoryAllocator> get_allocator(
    c10::DeviceType device_type);

// Set a store for rendezvousing symmetric allocations on a group of devices
// identified by `group_name`. The concept of groups is logical; users can
// utilize predefined groups (e.g., a group of devices identified by a
// ProcessGroup) or create custom ones. Note that a SymmetricMemoryAllocator
// backend might employ a more efficient communication channel for the actual
// rendezvous process and only use the store for bootstrapping purposes.
TORCH_API void set_group_info(
    const std::string& group_name,
    int rank,
    int world_size,
    c10::intrusive_ptr<Store> store);

struct GroupInfo {
  int rank;
  int world_size;
  c10::intrusive_ptr<Store> store;
};

C10_EXPORT const GroupInfo& get_group_info(const std::string& group_name);

// Identical to empty_strided, but allows symmetric memory access to be
// established for the allocated tensor via SymmetricMemory::rendezvous(). This
// function itself is not a collective operation. It invokes
// SymmetricMemoryAllocator::alloc() for the requested device under the hood.
//
// NOTE [symmetric memory persistent allocation]
// If an `alloc_id` is supplied, empty_strided_p2p will perform persistent
// allocation. This makes the function cache allocated memory and ensure that
// invocations with the same `alloc_id` receive tensors backed by the same
// memory address. For safety, if a previous persistent allocation is still
// active (i.e., the storage of the returned tensor is still alive), persistent
// allocations with the same `alloc_id` will fail. This determinism, coupled
// with memory planning of communication buffers (e.g., by Inductor), allows
// communication algorithms to reliably reuse previously established remote
// memory access.
TORCH_API at::Tensor empty_strided_p2p(
    c10::IntArrayRef size,
    c10::IntArrayRef stride,
    c10::ScalarType dtype,
    c10::Device device,
    const std::string& group_name,
    std::optional<uint64_t> alloc_id);
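// Example (sketch): registering group info and requesting a persistent
// allocation. This is illustrative only; the group name "my_group", the
// `alloc_id` value, and the variables `rank`, `world_size`, and `device` are
// placeholders, and the store construction is elided because it depends on
// the bootstrapping mechanism in use (e.g., a store owned by a ProcessGroup).
//
//   c10::intrusive_ptr<Store> store = /* obtained from the bootstrap layer */;
//   set_group_info("my_group", rank, world_size, store);
//
//   // With the same `alloc_id`, both calls return tensors backed by the same
//   // memory address, provided the first tensor's storage has been released
//   // before the second call (otherwise the second call fails).
//   at::Tensor a = empty_strided_p2p(
//       {1024}, {1}, at::kFloat, device, "my_group", /*alloc_id=*/42);
//   // ... `a` goes out of scope and its storage is freed ...
//   at::Tensor b = empty_strided_p2p(
//       {1024}, {1}, at::kFloat, device, "my_group", /*alloc_id=*/42);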
// Establishes symmetric memory access on tensors allocated via
// empty_strided_p2p(). rendezvous() is a one-time process, and the mapping
// between a local memory region and the associated SymmetricMemory object is
// unique. Subsequent calls to rendezvous() with the same tensor, or with
// tensors backed by the same persistent allocation (i.e., allocated with the
// same `alloc_id`), will receive the cached SymmetricMemory object.
//
// The function has collective semantics and must be invoked simultaneously
// from all rendezvous participants.
TORCH_API c10::intrusive_ptr<SymmetricMemory> rendezvous(
    const at::Tensor& tensor);

// Returns the SymmetricMemory object associated with the tensor. It can only
// be invoked after rendezvous() but does not need to be invoked collectively.
TORCH_API c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
    const at::Tensor& tensor);

TORCH_API bool has_multicast_support(
    c10::DeviceType device_type,
    int device_idx);

} // namespace symmetric_memory
} // namespace c10d
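// Example (sketch): point-to-point signaling on an established SymmetricMemory
// object. The ranks, channel indices, and the tensor `buf` are placeholders;
// the only assumption is that rendezvous() has already completed on `buf` on
// all participants.
//
//   auto symm_mem = c10d::symmetric_memory::get_symmetric_memory(buf);
//   if (symm_mem->get_rank() == 0) {
//     // Producer: after filling its local buffer (not shown), notify rank 1
//     // on channel 1, kept separate from barriers issued on channel 0.
//     symm_mem->put_signal(/*dst_rank=*/1, /*channel=*/1);
//   } else if (symm_mem->get_rank() == 1) {
//     // Consumer: block until rank 0's signal arrives on channel 1.
//     symm_mem->wait_signal(/*src_rank=*/0, /*channel=*/1);
//   }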