import functools
import math
from enum import IntEnum

import sympy

import torch

from . import ir
from .utils import get_dtype_size, sympy_product
from .virtualized import V


class NCCL_COLL(IntEnum):
    """Collective kinds this cost model knows how to estimate."""

    ALL_REDUCE = 0
    ALL_GATHER = 1
    REDUCE_SCATTER = 2


class NVIDIA_GPU_TYPE(IntEnum):
    """NVIDIA GPU generations distinguished by the LL-bandwidth table below."""

    VOLTA = 0
    AMPERE = 1
    HOPPER = 2


@functools.lru_cache
def get_gpu_type() -> NVIDIA_GPU_TYPE:
    """Best-effort detection of the local GPU generation from collect_env output.

    Cached for the process lifetime. Unrecognized GPUs fall back to Ampere.
    """
    gpu_info = torch.utils.collect_env.get_gpu_info(torch.utils.collect_env.run) or ""
    if "V100" in gpu_info:
        return NVIDIA_GPU_TYPE.VOLTA
    elif "A100" in gpu_info:
        return NVIDIA_GPU_TYPE.AMPERE
    elif "H100" in gpu_info:
        return NVIDIA_GPU_TYPE.HOPPER
    else:
        # for other gpu types, assume Ampere
        return NVIDIA_GPU_TYPE.AMPERE


def get_collective_type(node: ir.IRNode) -> NCCL_COLL:
    """Classify a collective kernel node by its python kernel name.

    Raises ValueError if the node is not a collective kernel or the kernel
    name does not match any supported collective.
    """
    if not isinstance(node, ir._CollectiveKernel):
        raise ValueError(f"node is not a collective kernel: {node}")

    kernel_name = node.python_kernel_name
    assert kernel_name is not None
    if "all_reduce" in kernel_name:
        return NCCL_COLL.ALL_REDUCE
    elif "all_gather" in kernel_name:
        return NCCL_COLL.ALL_GATHER
    elif "reduce_scatter" in kernel_name:
        return NCCL_COLL.REDUCE_SCATTER
    else:
        raise ValueError(f"Unsupported collective kernel: {kernel_name}")


def get_collective_input_size_bytes(node: ir.IRNode) -> int:
    """Sum the byte sizes of all inputs of a collective node.

    Symbolic sizes are resolved through the graph's size hints (falling back
    to 0 when no hint is available).
    """
    sz_bytes = 0
    for inp in node.inputs:  # type: ignore[attr-defined]
        numel = sympy_product(inp.layout.size)
        if isinstance(numel, sympy.Integer):
            # For ease of testing
            numel = int(numel)
        else:
            numel = V.graph.sizevars.size_hint(numel, fallback=0)
        sz_bytes += numel * get_dtype_size(inp.layout.dtype)
    return sz_bytes


def get_collective_group_size(node: ir.IRNode) -> int:
    """Return the size of the process group the collective runs over.

    NOTE: deliberately an exact type check (not isinstance) so that
    subclasses of _CollectiveKernel (e.g. wait kernels) are rejected.
    """
    if type(node) == ir._CollectiveKernel:
        from torch.distributed.distributed_c10d import _get_group_size_by_name

        # By convention the group name is the last constant argument.
        return _get_group_size_by_name(node.constant_args[-1])
    else:
        raise TypeError(f"Unsupported collective type: {node}")


####################################################################################################################
# The following code and constants are adapted from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc #
####################################################################################################################


class NCCL_HW(IntEnum):
    """Interconnect hardware tiers used by the latency table."""

    NVLINK = 0
    PCI = 1
    NET = 2


class NCCL_ALGO(IntEnum):
    """NCCL algorithms modeled here (only TREE and RING)."""

    TREE = 0
    RING = 1


class NCCL_PROTO(IntEnum):
    # The ordering and enum values here matches original in
    # https://github.com/NVIDIA/nccl/blob/0b083e52096c387bad7a5c5c65b26a9dca54de8c/src/include/devcomm.h#L28
    # For difference between these protocols, see https://github.com/NVIDIA/nccl/issues/281#issuecomment-571816990
    LL = 0  # Low-latency
    # LL128 = 1 # Low-latency 128-byte
    # SIMPLE = 2


# Latencies in us
# len(NCCL_ALGO) x len(NCCL_PROTO)
# NOTE: use array instead of tensor to prevent incompatibility with fake mode
baseLat = [
    # Tree
    [
        6.8,  # LL
    ],
    # Ring
    [
        6.6,  # LL
    ],
]

# Latencies in us
# len(NCCL_HW) x len(NCCL_ALGO) x len(NCCL_PROTO)
hwLat = [
    # NVLINK
    [
        [0.6],  # Tree (LL)
        [0.6],  # Ring (LL)
    ],
    # PCI
    [
        [1.0],  # Tree (LL)
        [1.0],  # Ring (LL)
    ],
    # NET
    [
        [5.0],  # Tree (LL)
        [2.7],  # Ring (LL)
    ],
]


# LL128 max BW per channel
llMaxBws = [
    # Volta-N1/Intel-N2/Intel-N4
    [
        39.0,
        39.0,
        20.4,
    ],
    # Ampere-N1/AMD-N2/AMD-N4
    [
        87.7,
        22.5,  # avg of ring & tree
        19.0,
    ],
    # Hopper-N1/AMD-N2/AMD-N4
    [
        87.7,
        22.5,  # avg of ring & tree
        19.0,
    ],
]


def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
    """
    Returns estimated NCCL collective runtime in nanoseconds (ns).

    The following heuristics are copied from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc.
    We aim to estimate the runtime as accurately as possible.

    Assumptions:
    - only ring algorithm (NCCL_ALGO_RING) is used
    - only Low-Latency protocol (NCCL_PROTO_LL) is used, i.e. Simple or LL128 is not used
    - 8 gpus per node  # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
    - collective is one of: allreduce, reducescatter, allgather
    """
    tensor_storage_size_bytes = get_collective_input_size_bytes(node)
    # Convert bytes to GB
    tensor_storage_size_GB = tensor_storage_size_bytes / 1024 / 1024 / 1024

    # Currently assumes each node has 8 gpus. And when >1 node is used, assumes each node uses all 8 gpus.
    # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
    num_gpus_per_node = 8
    group_size = get_collective_group_size(node)
    nNodes = math.ceil(group_size / num_gpus_per_node)
    nRanks = group_size  # this is total # of gpus globally that participate in this collective op

    if nRanks <= 1:
        return 0.0

    # Assumes ring algorithm
    nccl_algo = NCCL_ALGO.RING
    nccl_proto = NCCL_PROTO.LL
    coll = get_collective_type(node)

    # =============== bandwidth computation ===============
    # First compute bandwidth in GB/s; then at the end, convert it to GB/ns

    bwIntra = torch._inductor.config.intra_node_bw
    bwInter = torch._inductor.config.inter_node_bw

    compCapIndex = get_gpu_type()
    index2 = nNodes - 1 if nNodes <= 2 else 2
    # LL: for single node, we look at GPU type; for multi-node, we look at CPU type
    index1 = compCapIndex if nNodes == 1 else 0
    llMaxBw = llMaxBws[index1][index2]

    # NOTE: each step of ring algorithm is synchronized,
    # and is bottlenecked by the slowest link which is the inter-node interconnect.
    # hence when nNodes >= 2, bw is inter-node bandwidth.
    # NOTE: the original code in https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc
    # have this as `if nNodes <= 2` which seems wrong. Corrected it here.
    bw = bwIntra if nNodes == 1 else bwInter
    nChannels = 2  # Assume # channels is 2
    busBw = nChannels * bw

    # Various model refinements
    busBw = min(
        llMaxBw,
        busBw
        * (1.0 / 4.0 if (nNodes > 1 or coll == NCCL_COLL.ALL_REDUCE) else 1.0 / 3.0),
    )

    # get_collective_type() guarantees coll is one of the three supported
    # collectives, so an if/else here is exhaustive (no possibly-undefined).
    if coll == NCCL_COLL.ALL_REDUCE:
        nsteps = 2 * (nRanks - 1)
    else:  # REDUCE_SCATTER or ALL_GATHER
        nsteps = nRanks - 1

    # Convert bus BW to algorithm BW (tensor bytes / algoBW = actual execution time)
    ratio = (1.0 * nRanks) / nsteps
    bandwidth = busBw * ratio
    # Convert GB/s to GB/ns
    bandwidth_GB_per_ns = bandwidth / 1e9

    # =============== latency computation ===============
    intraHw = NCCL_HW.NVLINK

    if coll == NCCL_COLL.ALL_REDUCE:
        if nNodes > 1:
            nInterSteps = 2 * nNodes
        else:
            nInterSteps = 0
    else:  # REDUCE_SCATTER or ALL_GATHER
        nInterSteps = nNodes - 1

    # First compute latency in us; then at the end, convert it to ns
    latency = baseLat[nccl_algo][nccl_proto]
    intraLat = hwLat[intraHw][nccl_algo][nccl_proto]
    interLat = hwLat[NCCL_HW.NET][nccl_algo][nccl_proto]

    # Inter-node rings still have to launch nsteps * net overhead.
    netOverhead = 0.0
    if nNodes > 1:
        netOverhead = 1.0  # getNetOverhead(comm);
    intraLat = max(intraLat, netOverhead)
    latency += (nsteps - nInterSteps) * intraLat + nInterSteps * interLat
    # Convert us to ns
    latency_ns = latency * 1e3

    # =============== final result ===============
    transport_ns = tensor_storage_size_GB / bandwidth_GB_per_ns
    return transport_ns + latency_ns


################################################################################################################
# The above code and constants are adapted from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc #
################################################################################################################