# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import contextlib
from typing import cast, Dict, Optional, Tuple

import torch
import torch._prims_common as utils
import torch.distributed._functional_collectives as funcol
import torch.distributed.distributed_c10d as c10d
from torch import Tensor
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
from torch.distributed.tensor._ops._math_ops import (
    _skip_dim,
    Reduction,
    replicate_reduction_dims,
)
from torch.distributed.tensor.placement_types import Placement


aten = torch.ops.aten


__all__ = ["loss_parallel"]


@contextlib.contextmanager
def loss_parallel():
    """
    A context manager that enables loss parallelism, where efficient parallelized loss computation
    can be performed when the input is sharded on the class dimension. Currently only the cross-entropy
    loss is supported.

    Within this context manager, one can use :func:`~torch.nn.functional.cross_entropy` or
    :class:`~torch.nn.CrossEntropyLoss` as usual, with the following assumptions on the input parameters.
    The corresponding ``backward()`` call, if any, also needs to happen under this context manager.

    Args:
        input (:class:`DTensor`):
            Input logits. Assumed to be sharded on the class dimension.
        target (Union[:class:`torch.Tensor`, :class:`DTensor`]):
            Must be ground truth class indices (class probabilities currently not supported).
            Assumed to be replicated across the ``DeviceMesh``.
        weight (Union[:class:`torch.Tensor`, :class:`DTensor`], optional):
            If given, assumed to be replicated across the ``DeviceMesh``.
        label_smoothing:
            Currently not supported.

    Returns:
        A replicated :class:`DTensor`.

    Example:
        A sharded DTensor is manually created here to showcase the usage.
        In practice, it is usually the output of a TP module.

        >>> # xdoctest: +SKIP("distributed")
        >>> from torch.distributed.tensor.parallel import loss_parallel
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>> ...
        >>> device_mesh = init_device_mesh("cuda", (8,))
        >>> input = torch.randn(4, 16, device="cuda", requires_grad=True)
        >>> dist_input = distribute_tensor(input, device_mesh, placements=[Shard(1)])
        >>> target = torch.randint(16, (4,), device="cuda")
        >>> with loss_parallel():
        >>>     loss = F.cross_entropy(dist_input, target, reduction="mean")
        >>>     loss.backward()
        >>> ...
    """
    _enable_custom_loss_ops()

    yield

    _disable_custom_loss_ops()


# Currently only needs to support one dimensional DeviceMesh; in general return
# the mesh_dim with placements[mesh_dim].is_shard(dim)
def _find_all_reduce_mesh_dim(placements: Tuple[Placement, ...], dim: int) -> int:
    if not len(placements) == 1:
        raise ValueError(
            "Currently loss_parallel() only supports input on one-dimensional DeviceMesh."
        )
    if not placements[0].is_shard(dim):
        raise ValueError(
            f"loss_parallel() should be enabled only when the input tensor is sharded on dimension {dim}."
        )
    return 0
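

# Wrap a plain torch.Tensor as a DTensor with the given placements (no
# communication; run_check=False), or validate that an already-constructed
# DTensor carries exactly those placements.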
def _cast_to_dtensor(
    tensor, placements: Tuple[Placement, ...], mesh: DeviceMesh
) -> DTensor:
    if isinstance(tensor, DTensor):
        if tensor.placements == placements:
            return tensor
        else:
            raise RuntimeError(f"Expected {placements} but got {tensor.placements}.")
    elif isinstance(tensor, torch.Tensor):
        return DTensor.from_local(
            tensor, device_mesh=mesh, placements=placements, run_check=False
        )
    else:
        raise TypeError(f"Unsupported type {type(tensor)}")


def _propagate_tensor_meta(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> TensorMeta:
    op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
    tensor_meta = DTensor._op_dispatcher.sharding_propagator._propagate_tensor_meta(
        op_info.schema
    )
    if isinstance(tensor_meta, TensorMeta):
        return tensor_meta
    elif isinstance(tensor_meta, tuple):
        return tensor_meta[0]
    else:
        raise RuntimeError(f"Unexpected tensor meta type: {type(tensor_meta)}.")


# NOTE: The implementation follows torch._decomp.decomposition._log_softmax,
# with all_reduce manually inserted to perform distributed computation.
def _log_softmax(x, dim, half_to_float, mesh, mesh_dim):
    x = x.contiguous()
    if half_to_float:
        assert x.dtype == torch.half
    computation_dtype, result_dtype = utils.elementwise_dtypes(
        x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
    x = x.to(computation_dtype)
    if x.numel() == 0:
        shifted = x
    else:
        x_max = torch.amax(x, dim, keepdim=True)
        x_max = funcol.all_reduce(
            x_max, reduceOp=c10d.ReduceOp.MAX.name, group=(mesh, mesh_dim)
        )
        shifted = x - x_max
    shifted_sumexp = torch.sum(torch.exp(shifted), dim, keepdim=True)
    shifted_sumexp = funcol.all_reduce(
        shifted_sumexp, reduceOp=c10d.ReduceOp.SUM.name, group=(mesh, mesh_dim)
    )
    shifted_logsumexp = torch.log(shifted_sumexp)
    result = shifted - shifted_logsumexp
    if not half_to_float:
        result = result.to(result_dtype)
    return result
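

# For reference, the sharded log-softmax above computes, on each local shard
# x_i of the class dimension,
#     m   = all_reduce_max(amax(x_i, dim))          # global max, for numerical stability
#     s   = all_reduce_sum(sum(exp(x_i - m), dim))  # global normalizer
#     out = (x_i - m) - log(s)
# which is exactly the local shard of log_softmax over the full (unsharded)
# class dimension; the only communication is the two all-reduces on tensors
# that have already been reduced along ``dim``.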


def _log_softmax_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    x = cast(DTensor, args[0])
    dim = cast(int, args[1])
    half_to_float = cast(bool, args[2])

    spec = x._spec
    mesh_dim = _find_all_reduce_mesh_dim(spec.placements, dim)

    output_tensor_meta = _propagate_tensor_meta(op_call, args, kwargs)

    res = _log_softmax(x._local_tensor, dim, half_to_float, spec.mesh, mesh_dim)

    res_spec = DTensorSpec(
        spec.mesh,
        spec.placements,
        tensor_meta=output_tensor_meta,
    )

    return DTensor(
        res,
        res_spec,
        requires_grad=res.requires_grad,
    )


# NOTE: As explained below at _nll_loss_and_log_softmax_backward, the
# _log_softmax_backward_handler does not actually do any computation.
def _log_softmax_backward_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    grad_output = cast(DTensor, args[0])
    input_dtype = cast(torch.dtype, args[3])
    return grad_output.to(input_dtype)


# NOTE: The implementation follows torch._decomp.decomposition._nll_loss_forward,
# with customized communication inserted to perform distributed computation.
def _nll_loss_forward(
    x: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    local_weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
    input_shape: torch.Size,
    channel_dim: int,
    mesh: DeviceMesh,
    mesh_dim: int,
) -> Tuple[Tensor, Tensor]:
    n_dims = x.dim()
    channel_dim = 1
    if n_dims < 2:
        channel_dim = 0

    def _weight_view(weight: Tensor) -> Tensor:
        if n_dims > 1:
            shape = [
                1,
            ] * n_dims
            shape[channel_dim] = weight.shape[0]
            w = weight.view(shape)
        else:
            w = weight
        return w

    if weight is not None:
        w = _weight_view(weight)
        assert local_weight is not None
        local_w = _weight_view(local_weight)
        x = x * local_w
    safe_target = torch.where(target != ignore_index, target, 0)
    safe_target_ = safe_target.unsqueeze(channel_dim)

    # The following code block is a distributed version of
    # result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim)
    partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=channel_dim)
    safe_target_partial_ = partial_placement._partition_value(
        safe_target_, mesh, mesh_dim
    )
    result_partial = torch.gather(x, channel_dim, safe_target_partial_)
    # an all_reduce happens here
    result_reduced = partial_placement._reduce_value(result_partial, mesh, mesh_dim)
    result = -result_reduced.squeeze(channel_dim)

    result = torch.where(target != ignore_index, result, 0)

    if reduction == Reduction.NONE.value and n_dims > 1:
        total_weight = x.new_full((), 0.0)
        return result, total_weight

    if weight is not None:
        new_shape = list(x.shape)
        new_shape[channel_dim] = -1
        w = w.expand(new_shape)
        wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim)
        wsum = torch.where(target != ignore_index, wsum, 0)
        total_weight = wsum.sum()
    else:
        total_weight = (target != ignore_index).sum().to(x)

    # NOTE: this is correct only on 1D DeviceMesh; o/w additional
    # all-reduce on result and total_weight is needed
    if reduction == Reduction.SUM.value:
        result = result.sum()
    elif reduction == Reduction.MEAN.value:
        result = result.sum() / total_weight

    return result, total_weight


def _nll_loss_forward_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    x = cast(DTensor, args[0])
    target = args[1]
    weight = args[2]
    reduction = cast(int, args[3])
    ignore_index = cast(int, args[4])

    channel_dim = 1 if x.dim() >= 2 else 0
    channel_dim_size = x.shape[channel_dim]
    spec = x._spec
    mesh_dim = _find_all_reduce_mesh_dim(spec.placements, channel_dim)

    # Check user input: if target and weight are not DTensors, convert them to DTensors;
    # if they are DTensors, check that they have the desired placements.
    target_placements = _skip_dim(
        replicate_reduction_dims(spec.placements, [channel_dim]), channel_dim
    )
    all_replicate_placements = (Replicate(),) * spec.mesh.ndim
    target = _cast_to_dtensor(target, target_placements, spec.mesh)
    local_weight = None
    if weight is not None:
        weight = _cast_to_dtensor(weight, all_replicate_placements, spec.mesh)
        # For local computation, both (replicated) weight and (sharded) local_weight
        # are needed in _nll_loss_forward(). local_weight is generated here using
        # DTensor API, without incurring any communication.
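        # (Since the weight is already replicated, redistributing it to Shard(0)
        # only requires each rank to take its own chunk locally.)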
        sharded_placements = [
            Shard(0) if i == mesh_dim else Replicate() for i in range(spec.mesh.ndim)
        ]
        local_weight = weight.redistribute(spec.mesh, sharded_placements)._local_tensor
        assert local_weight.shape[0] == x._local_tensor.shape[channel_dim]

    if reduction == Reduction.NONE.value:
        output_placements = target_placements
    else:
        output_placements = all_replicate_placements

    # tensor inputs to _propagate_tensor_meta need to be DTensors
    args = list(args)
    args[1], args[2] = target, weight
    output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs)

    result, total_weight = _nll_loss_forward(
        x._local_tensor,
        target._local_tensor,
        weight._local_tensor if weight is not None else None,
        local_weight,
        reduction,
        ignore_index,
        x.shape,
        channel_dim,
        spec.mesh,
        mesh_dim,
    )
    out_spec = DTensorSpec(spec.mesh, output_placements, tensor_meta=output_tensor_meta)

    return (
        DTensor(
            result,
            out_spec,
            requires_grad=result.requires_grad,
        ),
        total_weight,
    )


# NOTE: The backward computation of cross_entropy goes through two steps:
# backward for nll_loss and then backward for log_softmax. In loss parallel,
# the two steps are fused into the following function (called by _nll_loss_backward_handler)
# to avoid communication when target contains class indices, not class probabilities.
# Also note that the _log_softmax_backward_handler does not perform computation.
# The implementation resembles _nll_loss_backward and _log_softmax_backward_data
# from torch._decomp.decomposition.
def _nll_loss_and_log_softmax_backward(
    grad_output: Tensor,
    x: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
    total_weight: Tensor,
    input_shape: torch.Size,
    channel_dim: int,
    mesh: DeviceMesh,
    mesh_dim: int,
) -> Tensor:
    channel_dim = 0 if x.dim() < 2 else 1
    if reduction == Reduction.MEAN.value:
        grad_output = grad_output / total_weight

    target = target.unsqueeze(channel_dim)
    safe_target = torch.where(target != ignore_index, target, 0)
    grad_input = torch.zeros_like(x)

    # The following code block is a distributed version of
    # grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0)
    partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=channel_dim)
    safe_target = safe_target.squeeze(channel_dim).flatten()
    masked_safe_target = partial_placement._partition_value(safe_target, mesh, mesh_dim)
    # only update grad_input to -1 if not masked
    assert partial_placement.mask_buffer.data is not None
    grad_update = partial_placement.mask_buffer.data.to(grad_input.dtype) - 1.0
    arange_1d = torch.arange(
        masked_safe_target.shape[0], device=masked_safe_target.device
    )
    # The first two cases with x.dim() <= 2 are for aten.nll_loss_backward.default;
    # the last case is for aten.nll_loss2d_backward.default.
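    # (A 1-D input has no batch dimension; a 2-D input scatters one entry per row;
    # for higher-dimensional inputs the class dimension is moved last and the
    # remaining dimensions are flattened so the same 2-D scatter applies.)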
    if x.dim() == 1:
        grad_input[masked_safe_target] = grad_update
    elif x.dim() == 2:
        grad_input[arange_1d, masked_safe_target] = grad_update
    else:
        grad_input_t = grad_input.transpose(channel_dim, -1)
        intermediate_shape = grad_input_t.shape
        grad_input_2d = grad_input_t.reshape(-1, x.shape[channel_dim])
        grad_input_2d[arange_1d, masked_safe_target] = grad_update
        grad_input = grad_input_2d.view(intermediate_shape).transpose(channel_dim, -1)

    if grad_input.dim() > grad_output.dim() > 0:
        grad_output = grad_output.unsqueeze(channel_dim)

    if weight is not None:
        new_shape = [1 for _ in range(x.dim())]
        new_shape[channel_dim] = weight.shape[0]
        weight = weight.reshape(new_shape)
        # In order for fused computation to work, the following line is rewritten.
        # grad_output = grad_output * weight
        new_shape = list(x.shape)
        new_shape[channel_dim] = -1
        w = weight.expand(new_shape)
        w_target = torch.gather(w, channel_dim, target)
        grad_output = grad_output * w_target

    grad_output = torch.where(target != ignore_index, grad_output, 0)

    # NOTE: Instead of directly returning the grad_input as grad_output for log_softmax,
    # here we perform backward computation for log_softmax altogether to avoid the
    # otherwise extra all_gather communication.
    # return grad_input * grad_output
    return (grad_input + torch.exp(x)) * grad_output


def _nll_loss_backward_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    grad_output = cast(DTensor, args[0])
    x = cast(DTensor, args[1])
    target = args[2]
    weight = args[3]
    reduction = cast(int, args[4])
    ignore_index = cast(int, args[5])
    total_weight = cast(Tensor, args[6])

    channel_dim = 1 if x.dim() >= 2 else 0
    spec = x._spec
    mesh_dim = _find_all_reduce_mesh_dim(spec.placements, channel_dim)

    # if target and weight are not DTensors, convert them to DTensors
    target_placements = _skip_dim(
        replicate_reduction_dims(spec.placements, [channel_dim]), channel_dim
    )
    all_replicate_placements = (Replicate(),) * spec.mesh.ndim
    target = _cast_to_dtensor(target, target_placements, spec.mesh)
    if weight is not None:
        weight = _cast_to_dtensor(weight, all_replicate_placements, spec.mesh)

    # tensor inputs to _propagate_tensor_meta need to be DTensors
    args = list(args)
    args[2], args[3] = target, weight
    args[6] = _cast_to_dtensor(total_weight, all_replicate_placements, spec.mesh)
    output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs)

    result = _nll_loss_and_log_softmax_backward(
        grad_output._local_tensor,
        x._local_tensor,
        target._local_tensor,
        weight._local_tensor if weight is not None else None,
        reduction,
        ignore_index,
        total_weight,
        x.shape,
        channel_dim,
        spec.mesh,
        mesh_dim,
    )
    # the output sharding is the same as input sharding: Shard(channel_dim) on mesh_dim
    out_spec = DTensorSpec(
        spec.mesh,
        spec.placements,
        tensor_meta=output_tensor_meta,
    )

    return DTensor(
        result,
        out_spec,
        requires_grad=result.requires_grad,
    )
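

# Handlers installed by _enable_custom_loss_ops(): while loss_parallel() is
# active, DTensor dispatch routes these aten ops to the custom implementations
# above instead of its default op handling.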
customized_loss_ops = {
    aten._log_softmax.default: _log_softmax_handler,
    aten._log_softmax_backward_data.default: _log_softmax_backward_handler,
    aten.nll_loss_forward.default: _nll_loss_forward_handler,
    aten.nll_loss2d_forward.default: _nll_loss_forward_handler,
    aten.nll_loss_backward.default: _nll_loss_backward_handler,
    aten.nll_loss2d_backward.default: _nll_loss_backward_handler,
}


def _enable_custom_loss_ops():
    DTensor._op_dispatcher._custom_op_handlers.update(customized_loss_ops)


def _disable_custom_loss_ops():
    for custom_op in customized_loss_ops:
        DTensor._op_dispatcher._custom_op_handlers.pop(custom_op)