import functools
import math
from enum import IntEnum

import sympy

import torch

from . import ir
from .utils import get_dtype_size, sympy_product
from .virtualized import V


class NCCL_COLL(IntEnum):
    ALL_REDUCE = 0
    ALL_GATHER = 1
    REDUCE_SCATTER = 2


class NVIDIA_GPU_TYPE(IntEnum):
    VOLTA = 0
    AMPERE = 1
    HOPPER = 2


@functools.lru_cache
def get_gpu_type() -> NVIDIA_GPU_TYPE:
    gpu_info = torch.utils.collect_env.get_gpu_info(torch.utils.collect_env.run) or ""
    if "V100" in gpu_info:
        return NVIDIA_GPU_TYPE.VOLTA
    elif "A100" in gpu_info:
        return NVIDIA_GPU_TYPE.AMPERE
    elif "H100" in gpu_info:
        return NVIDIA_GPU_TYPE.HOPPER
    else:
        # for other gpu types, assume Ampere
        return NVIDIA_GPU_TYPE.AMPERE


def get_collective_type(node: ir.IRNode) -> NCCL_COLL:
    if not isinstance(node, ir._CollectiveKernel):
        raise ValueError(f"node is not a collective kernel: {node}")

    kernel_name = node.python_kernel_name
    assert kernel_name is not None
    if "all_reduce" in kernel_name:
        return NCCL_COLL.ALL_REDUCE
    elif "all_gather" in kernel_name:
        return NCCL_COLL.ALL_GATHER
    elif "reduce_scatter" in kernel_name:
        return NCCL_COLL.REDUCE_SCATTER
    else:
        raise ValueError(f"Unsupported collective kernel: {kernel_name}")


def get_collective_input_size_bytes(node: ir.IRNode) -> int:
    sz_bytes = 0
    for inp in node.inputs:  # type: ignore[attr-defined]
        numel = sympy_product(inp.layout.size)
        if isinstance(numel, sympy.Integer):
            # For ease of testing
            numel = int(numel)
        else:
            numel = V.graph.sizevars.size_hint(numel, fallback=0)
        sz_bytes += numel * get_dtype_size(inp.layout.dtype)
    return sz_bytes

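# Worked example (for illustration): a single float16 input of shape (1024, 1024)
# contributes 1024 * 1024 * 2 bytes = 2 MiB to sz_bytes, since
# get_dtype_size(torch.float16) == 2.
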

def get_collective_group_size(node: ir.IRNode) -> int:
    if type(node) == ir._CollectiveKernel:
        from torch.distributed.distributed_c10d import _get_group_size_by_name

        return _get_group_size_by_name(node.constant_args[-1])
    else:
        raise TypeError(f"Unsupported collective type: {node}")


####################################################################################################################
# The following code and constants are adapted from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc #
####################################################################################################################


class NCCL_HW(IntEnum):
    NVLINK = 0
    PCI = 1
    NET = 2


class NCCL_ALGO(IntEnum):
    TREE = 0
    RING = 1


class NCCL_PROTO(IntEnum):
    # The ordering and enum values here match the original in
    # https://github.com/NVIDIA/nccl/blob/0b083e52096c387bad7a5c5c65b26a9dca54de8c/src/include/devcomm.h#L28
    # For the difference between these protocols, see https://github.com/NVIDIA/nccl/issues/281#issuecomment-571816990
    LL = 0  # Low-latency
    # LL128 = 1   # Low-latency 128-byte
    # SIMPLE = 2


# Latencies in us
# len(NCCL_ALGO) x len(NCCL_PROTO)
# NOTE: use plain Python lists instead of tensors to avoid incompatibility with fake mode
baseLat = [
    # Tree
    [
        6.8,  # LL
    ],
    # Ring
    [
        6.6,  # LL
    ],
]

# Latencies in us
# len(NCCL_HW) x len(NCCL_ALGO) x len(NCCL_PROTO)
hwLat = [
    # NVLINK
    [
        [0.6],  # Tree (LL)
        [0.6],  # Ring (LL)
    ],
    # PCI
    [
        [1.0],  # Tree (LL)
        [1.0],  # Ring (LL)
    ],
    # NET
    [
        [5.0],  # Tree (LL)
        [2.7],  # Ring (LL)
    ],
]
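# Example lookups (worked from the tables above): baseLat[NCCL_ALGO.RING][NCCL_PROTO.LL]
# is 6.6 us and hwLat[NCCL_HW.NET][NCCL_ALGO.RING][NCCL_PROTO.LL] is 2.7 us; these are the
# values used for ring/LL collectives in estimate_nccl_collective_runtime below.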


# LL128 max BW per channel
llMaxBws = [
    # Volta-N1/Intel-N2/Intel-N4
    [
        39.0,
        39.0,
        20.4,
    ],
    # Ampere-N1/AMD-N2/AMD-N4
    [
        87.7,
        22.5,  # avg of ring & tree
        19.0,
    ],
    # Hopper-N1/AMD-N2/AMD-N4
    [
        87.7,
        22.5,  # avg of ring & tree
        19.0,
    ],
]
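# Indexing example (matching the lookup in estimate_nccl_collective_runtime below): for a
# single-node Ampere job, llMaxBws[NVIDIA_GPU_TYPE.AMPERE][0] == 87.7 GB/s; for a 2-node job
# the row index falls back to 0 and the column is nNodes - 1, so llMaxBws[0][1] == 39.0 GB/s.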


def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
    """
    Returns estimated NCCL collective runtime in nanoseconds (ns).

    The following heuristics are copied from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc.
    We aim to estimate the runtime as accurately as possible.

    Assumptions:
    - only ring algorithm (NCCL_ALGO_RING) is used
    - only Low-Latency protocol (NCCL_PROTO_LL) is used, i.e. neither Simple nor LL128 is used
    - 8 gpus per node  # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
    - collective is one of: allreduce, reducescatter, allgather
    """
    tensor_storage_size_bytes = get_collective_input_size_bytes(node)
    # Convert bytes to GB
    tensor_storage_size_GB = tensor_storage_size_bytes / 1024 / 1024 / 1024

    # Currently assumes each node has 8 GPUs, and that when more than one node is used,
    # every node uses all 8 of its GPUs.
    # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
    num_gpus_per_node = 8
    group_size = get_collective_group_size(node)
    nNodes = math.ceil(group_size / num_gpus_per_node)
    nRanks = group_size  # total number of GPUs globally that participate in this collective op

    if nRanks <= 1:
        return 0

    # Assumes ring algorithm
    nccl_algo = NCCL_ALGO.RING
    nccl_proto = NCCL_PROTO.LL
    coll = get_collective_type(node)

    # =============== bandwidth computation ===============
    # First compute bandwidth in GB/s; then at the end, convert it to GB/ns

    bwIntra = torch._inductor.config.intra_node_bw
    bwInter = torch._inductor.config.inter_node_bw

    compCapIndex = get_gpu_type()
    index2 = nNodes - 1 if nNodes <= 2 else 2
    # LL: for single node, we look at GPU type; for multi-node, we look at CPU type
    index1 = compCapIndex if nNodes == 1 else 0
    llMaxBw = llMaxBws[index1][index2]

    # NOTE: each step of the ring algorithm is synchronized and is bottlenecked by the
    # slowest link, which is the inter-node interconnect; hence when nNodes >= 2, bw is
    # the inter-node bandwidth.
    # NOTE: the original code in https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc
    # has this as `if nNodes <= 2`, which seems wrong; corrected here.
    bw = bwIntra if nNodes == 1 else bwInter
    nChannels = 2  # Assume # channels is 2
    busBw = nChannels * bw

    # Various model refinements
    busBw = min(
        llMaxBw,
        busBw
        * (1.0 / 4.0 if (nNodes > 1 or coll == NCCL_COLL.ALL_REDUCE) else 1.0 / 3.0),
    )

    if coll == NCCL_COLL.ALL_REDUCE:
        nsteps = 2 * (nRanks - 1)
    elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER):
        nsteps = nRanks - 1

    # Convert bus BW to algorithm BW (tensor bytes / algoBW = actual execution time)
    ratio = (1.0 * nRanks) / nsteps  # type: ignore[possibly-undefined]
    bandwidth = busBw * ratio
    # Convert GB/s to GB/ns
    bandwidth_GB_per_ns = bandwidth / 1e9
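    # Worked example (illustrative, assuming a hypothetical inter_node_bw of 25 GB/s;
    # the real value comes from torch._inductor.config): an all_reduce across
    # 2 nodes x 8 GPUs has nRanks = 16, nsteps = 30, llMaxBw = llMaxBws[0][1] = 39.0,
    # busBw = min(39.0, 2 * 25 * 0.25) = 12.5 GB/s, and
    # bandwidth = 12.5 * 16 / 30 ≈ 6.67 GB/s = 6.67e-9 GB/ns.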

    # =============== latency computation ===============
    intraHw = NCCL_HW.NVLINK

    if coll == NCCL_COLL.ALL_REDUCE:
        if nNodes > 1:
            nInterSteps = 2 * nNodes
        else:
            nInterSteps = 0
    elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER):
        nInterSteps = nNodes - 1

    # First compute latency in us; then at the end, convert it to ns
    latency = baseLat[nccl_algo][nccl_proto]
    intraLat = hwLat[intraHw][nccl_algo][nccl_proto]
    interLat = hwLat[NCCL_HW.NET][nccl_algo][nccl_proto]

    # Inter-node rings still have to launch nsteps * net overhead.
    netOverhead = 0.0
    if nNodes > 1:
        netOverhead = 1.0  # getNetOverhead(comm);
    intraLat = max(intraLat, netOverhead)
    latency += (nsteps - nInterSteps) * intraLat + nInterSteps * interLat  # type: ignore[possibly-undefined]
    # Convert us to ns
    latency_ns = latency * 1e3
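    # Worked example (continuing the 2-node x 8-GPU all_reduce above): nInterSteps = 2 * 2 = 4,
    # intraLat = max(0.6, 1.0) = 1.0 us, interLat = 2.7 us, so
    # latency = 6.6 + (30 - 4) * 1.0 + 4 * 2.7 = 43.4 us, i.e. latency_ns = 43400.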

    # =============== final result ===============
    transport_ns = tensor_storage_size_GB / bandwidth_GB_per_ns
    return transport_ns + latency_ns


################################################################################################################
# The above code and constants are adapted from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc #
################################################################################################################
