# mypy: allow-untyped-defs
from typing import List, Optional

import torch
import torch.distributed.distributed_c10d as c10d


"""
This file contains the op impls for the legacy (c10d_functional) functional collectives.
These impls simply call into the native (_c10d_functional) functional collectives.
"""


def _broadcast(input, src, tag, ranks, group_size):
    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.broadcast(
        input,
        src,
        group_name,
    )


def _all_reduce(input, reduce_op, tag, ranks, group_size):
    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.all_reduce(
        input,
        reduce_op,
        group_name,
    )


def _all_reduce_coalesced(inputs, reduce_op, tag, ranks, group_size):
    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.all_reduce_coalesced(
        inputs,
        reduce_op,
        group_name,
    )


def _all_gather_into_tensor(input, tag, ranks, group_size):
    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.all_gather_into_tensor(
        input,
        group_size,
        group_name,
    )


def _all_gather_into_tensor_coalesced(input, tag, ranks, group_size):
    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.all_gather_into_tensor_coalesced(
        input,
        group_size,
        group_name,
    )


def _reduce_scatter_tensor(
    input: torch.Tensor,
    reduce_op: str,
    tag: str,
    ranks: List[int],
    group_size: int,
):
    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.reduce_scatter_tensor(
        input,
        reduce_op,
        group_size,
        group_name,
    )


def _reduce_scatter_tensor_coalesced(
    inputs: List[torch.Tensor],
    reduce_op: str,
    tag: str,
    ranks: List[int],
    group_size: int,
):
    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.reduce_scatter_tensor_coalesced(
        inputs,
        reduce_op,
        group_size,
        group_name,
    )


def _all_to_all_single(
    input: torch.Tensor,
    output_split_sizes: Optional[List[int]],
    input_split_sizes: Optional[List[int]],
    tag: str,
    ranks: List[int],
    group_size: int,
):
    if output_split_sizes is None or input_split_sizes is None:
        assert output_split_sizes is None and input_split_sizes is None, (
            "output_split_sizes and input_split_sizes must either be "
            "specified together or both set to None"
        )
        output_split_sizes = [input.shape[0] // group_size] * group_size
        input_split_sizes = output_split_sizes

    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.all_to_all_single(
        input,
        output_split_sizes,
        input_split_sizes,
        group_name,
    )


def _wait_tensor(tensor: torch.Tensor) -> torch.Tensor:
    return torch.ops._c10d_functional.wait_tensor(tensor)
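
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): the wrappers
# above identify a process group by the legacy (tag, ranks) pair and return
# tensors backed by pending collectives, so results must pass through
# _wait_tensor() before use. The single-process gloo group and the TCP
# init_method below are assumptions made so the sketch can run standalone;
# real workloads would use an already-initialized process group.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.distributed as dist

    # One-rank "world" so the collectives complete without peer processes.
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29500",  # assumed-free local port
        rank=0,
        world_size=1,
    )
    ranks = list(range(dist.get_world_size()))

    t = torch.ones(4)
    # An empty tag resolves to the default process group for these ranks.
    out = _all_reduce(t, "sum", "", ranks, len(ranks))
    out = _wait_tensor(out)  # block until the collective has completed
    print(out)  # tensor([1., 1., 1., 1.]) with a single rank

    # With both split-size lists None, _all_to_all_single assumes an even
    # split of dim 0 across the group (shape[0] must divide by group_size).
    shuffled = _wait_tensor(_all_to_all_single(t, None, None, "", ranks, len(ranks)))
    print(shuffled)

    dist.destroy_process_group()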