# mypy: allow-untyped-defs
from typing import List, Optional

import torch
import torch.distributed.distributed_c10d as c10d


"""
This file contains the op impls for the legacy (c10d_functional) functional collectives.
These impls simply call into the native (_c10d_functional) functional collectives.
"""


def _broadcast(
    input: torch.Tensor,
    src: int,
    tag: str,
    ranks: List[int],
    group_size: int,
):
    # group_size is unused here; the native op only needs the group name.
    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.broadcast(
        input,
        src,
        group_name,
    )


def _all_reduce(
    input: torch.Tensor,
    reduce_op: str,
    tag: str,
    ranks: List[int],
    group_size: int,
):
    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.all_reduce(
        input,
        reduce_op,
        group_name,
    )


def _all_reduce_coalesced(
    inputs: List[torch.Tensor],
    reduce_op: str,
    tag: str,
    ranks: List[int],
    group_size: int,
):
    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.all_reduce_coalesced(
        inputs,
        reduce_op,
        group_name,
    )


def _all_gather_into_tensor(
    input: torch.Tensor,
    tag: str,
    ranks: List[int],
    group_size: int,
):
    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.all_gather_into_tensor(
        input,
        group_size,
        group_name,
    )


def _all_gather_into_tensor_coalesced(
    input: List[torch.Tensor],
    tag: str,
    ranks: List[int],
    group_size: int,
):
    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.all_gather_into_tensor_coalesced(
        input,
        group_size,
        group_name,
    )


def _reduce_scatter_tensor(
    input: torch.Tensor,
    reduce_op: str,
    tag: str,
    ranks: List[int],
    group_size: int,
):
    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.reduce_scatter_tensor(
        input,
        reduce_op,
        group_size,
        group_name,
    )


def _reduce_scatter_tensor_coalesced(
    inputs: List[torch.Tensor],
    reduce_op: str,
    tag: str,
    ranks: List[int],
    group_size: int,
):
    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.reduce_scatter_tensor_coalesced(
        inputs,
        reduce_op,
        group_size,
        group_name,
    )


def _all_to_all_single(
    input: torch.Tensor,
    output_split_sizes: Optional[List[int]],
    input_split_sizes: Optional[List[int]],
    tag: str,
    ranks: List[int],
    group_size: int,
):
    if output_split_sizes is None or input_split_sizes is None:
        assert output_split_sizes is None and input_split_sizes is None, (
            "output_split_sizes and input_split_sizes must either be "
            "specified together or both set to None"
        )
        # Default to an even split of dim 0 across the ranks in the group.
        output_split_sizes = [input.shape[0] // group_size] * group_size
        input_split_sizes = output_split_sizes

    group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
    return torch.ops._c10d_functional.all_to_all_single(
        input,
        output_split_sizes,
        input_split_sizes,
        group_name,
    )
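
# For example (illustrative): with input.shape[0] == 8 and group_size == 4 and
# no split sizes given, both split lists default to [2, 2, 2, 2], i.e. each
# rank sends and receives an even slice of dim 0.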


def _wait_tensor(tensor: torch.Tensor) -> torch.Tensor:
    return torch.ops._c10d_functional.wait_tensor(tensor)
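
# Example usage (a minimal sketch; assumes an initialized process group whose
# ranks/tag resolve as below -- names like `t` and `world_size` are
# illustrative):
#
#   out = _all_reduce(t, "sum", "", list(range(world_size)), world_size)
#   out = _wait_tensor(out)  # collectives are async; wait before reading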