# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import contextlib
from typing import cast, Dict, Optional, Tuple

import torch
import torch._prims_common as utils
import torch.distributed._functional_collectives as funcol
import torch.distributed.distributed_c10d as c10d
from torch import Tensor
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._ops._embedding_ops import _MaskPartial
from torch.distributed.tensor._ops._math_ops import (
    _skip_dim,
    Reduction,
    replicate_reduction_dims,
)
from torch.distributed.tensor.placement_types import Placement


aten = torch.ops.aten


__all__ = ["loss_parallel"]


@contextlib.contextmanager
def loss_parallel():
    """
    A context manager that enables loss parallelism, where efficient parallelized loss computation
    can be performed when the input is sharded on the class dimension. Currently only the cross-entropy
    loss is supported.

    Within this context manager, one can use :func:`~torch.nn.functional.cross_entropy` or
    :class:`~torch.nn.CrossEntropyLoss` as usual, with the following assumptions on the input parameters.
    The corresponding ``backward()`` call, if any, also needs to happen under this context manager.

    Args:
        input (:class:`DTensor`):
            Input logits. Assumed to be sharded on the class dimension.
        target (Union[:class:`torch.Tensor`, :class:`DTensor`]):
            Must be ground truth class indices (class probabilities currently not supported).
            Assumed to be replicated across the ``DeviceMesh``.
        weight (Union[:class:`torch.Tensor`, :class:`DTensor`], optional):
            If given, assumed to be replicated across the ``DeviceMesh``.
        label_smoothing:
            Currently not supported.

    Returns:
        A replicated :class:`DTensor`.

    Example:
        A sharded DTensor is manually created here to showcase the usage.
        In practice, it is usually the output of a TP module.

        >>> # xdoctest: +SKIP("distributed")
        >>> from torch.distributed.tensor.parallel import loss_parallel
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>> ...
        >>> device_mesh = init_device_mesh("cuda", (8,))
        >>> input = torch.randn(4, 16, device="cuda", requires_grad=True)
        >>> dist_input = distribute_tensor(input, device_mesh, placements=[Shard(1)])
        >>> target = torch.randint(16, (4,), device="cuda")
        >>> with loss_parallel():
        >>>     loss = F.cross_entropy(dist_input, target, reduction="mean")
        >>>     loss.backward()
        >>> ...
    """
    _enable_custom_loss_ops()

    try:
        yield
    finally:
        # Make sure the custom handlers are removed even if the wrapped code raises.
        _disable_custom_loss_ops()

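# A minimal sketch (not used by this module) of how loss_parallel() is
# typically combined with tensor parallelism: the final projection layer is
# parallelized column-wise so that its output stays a DTensor sharded on the
# class dimension, and the loss plus its backward run under the context
# manager. ``model.output``, ``inputs``, and ``target`` are hypothetical names,
# and a default process group with 8 ranks is assumed to be initialized.
#
#     import torch.nn.functional as F
#     from torch.distributed.device_mesh import init_device_mesh
#     from torch.distributed.tensor import Shard
#     from torch.distributed.tensor.parallel import (
#         ColwiseParallel,
#         loss_parallel,
#         parallelize_module,
#     )
#
#     mesh = init_device_mesh("cuda", (8,))
#     # keep the logits as a DTensor sharded on the last (class) dimension
#     parallelize_module(
#         model.output,
#         mesh,
#         ColwiseParallel(output_layouts=Shard(-1), use_local_output=False),
#     )
#     with loss_parallel():
#         loss = F.cross_entropy(model(inputs), target)
#         loss.backward()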

# Currently only needs to support one dimensional DeviceMesh; in general return
# the mesh_dim with placements[mesh_dim].is_shard(dim)
def _find_all_reduce_mesh_dim(placements: Tuple[Placement, ...], dim: int) -> int:
    if not len(placements) == 1:
        raise ValueError(
            "Currently loss_parallel() only supports input on one-dimensional DeviceMesh."
        )
    if not placements[0].is_shard(dim):
        raise ValueError(
            f"loss_parallel() should be enabled only when the input tensor is sharded on dimension {dim}."
        )
    return 0

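# A sketch of the general n-D behavior described in the comment above: scan the
# placements for the mesh dimension that shards ``dim``. It is not used by this
# module, which only supports one-dimensional meshes today, and the helper name
# is hypothetical.
#
#     def _find_all_reduce_mesh_dim_general(
#         placements: Tuple[Placement, ...], dim: int
#     ) -> int:
#         for mesh_dim, placement in enumerate(placements):
#             if placement.is_shard(dim):
#                 return mesh_dim
#         raise ValueError(f"No mesh dimension shards tensor dimension {dim}.")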

def _cast_to_dtensor(
    tensor, placements: Tuple[Placement, ...], mesh: DeviceMesh
) -> DTensor:
    if isinstance(tensor, DTensor):
        if tensor.placements == placements:
            return tensor
        else:
            raise RuntimeError(f"Expected {placements} but got {tensor.placements}.")
    elif isinstance(tensor, torch.Tensor):
        return DTensor.from_local(
            tensor, device_mesh=mesh, placements=placements, run_check=False
        )
    else:
        raise TypeError(f"Unsupported type {type(tensor)}")

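# Run DTensor's sharding propagation on the op schema to obtain the output
# tensor metadata (shape, stride, dtype); the handlers below use it to build
# the DTensorSpec of their outputs. For multi-output ops, only the first
# output's metadata is returned.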
def _propagate_tensor_meta(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> TensorMeta:
    op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
    tensor_meta = DTensor._op_dispatcher.sharding_propagator._propagate_tensor_meta(
        op_info.schema
    )
    if isinstance(tensor_meta, TensorMeta):
        return tensor_meta
    elif isinstance(tensor_meta, tuple):
        return tensor_meta[0]
    else:
        raise RuntimeError(f"Unexpected tensor meta type: {type(tensor_meta)}.")


# NOTE: The implementation follows torch._decomp.decomposition._log_softmax,
# with all_reduce manually inserted to perform distributed computation.
def _log_softmax(x, dim, half_to_float, mesh, mesh_dim):
    x = x.contiguous()
    if half_to_float:
        assert x.dtype == torch.half
    computation_dtype, result_dtype = utils.elementwise_dtypes(
        x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
    x = x.to(computation_dtype)
    if x.numel() == 0:
        shifted = x
    else:
        x_max = torch.amax(x, dim, keepdim=True)
        x_max = funcol.all_reduce(
            x_max, reduceOp=c10d.ReduceOp.MAX.name, group=(mesh, mesh_dim)
        )
        shifted = x - x_max
    shifted_sumexp = torch.sum(torch.exp(shifted), dim, keepdim=True)
    shifted_sumexp = funcol.all_reduce(
        shifted_sumexp, reduceOp=c10d.ReduceOp.SUM.name, group=(mesh, mesh_dim)
    )
    shifted_logsumexp = torch.log(shifted_sumexp)
    result = shifted - shifted_logsumexp
    if not half_to_float:
        result = result.to(result_dtype)
    return result

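# A minimal single-process sketch of why the two all_reduce calls above are
# sufficient: sharing only the per-shard max and the per-shard sum of
# exponentials reproduces the unsharded log_softmax on every shard. The
# collectives are emulated here by reducing over a list of chunks.
#
#     x = torch.randn(4, 16)
#     chunks = x.chunk(4, dim=-1)  # each "rank" holds one slice of the classes
#     global_max = torch.amax(
#         torch.stack([c.amax(-1, keepdim=True) for c in chunks]), dim=0
#     )  # emulated MAX all_reduce
#     global_sumexp = sum(
#         (c - global_max).exp().sum(-1, keepdim=True) for c in chunks
#     )  # emulated SUM all_reduce
#     result = torch.cat(
#         [c - global_max - global_sumexp.log() for c in chunks], dim=-1
#     )
#     torch.testing.assert_close(result, torch.log_softmax(x, dim=-1))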

def _log_softmax_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    x = cast(DTensor, args[0])
    dim = cast(int, args[1])
    half_to_float = cast(bool, args[2])

    spec = x._spec
    mesh_dim = _find_all_reduce_mesh_dim(spec.placements, dim)

    output_tensor_meta = _propagate_tensor_meta(op_call, args, kwargs)

    res = _log_softmax(x._local_tensor, dim, half_to_float, spec.mesh, mesh_dim)

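    # Each rank's local result is exactly its slice of the global log_softmax,
    # so the output keeps the same Shard(dim) placement as the input.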
    res_spec = DTensorSpec(
        spec.mesh,
        spec.placements,
        tensor_meta=output_tensor_meta,
    )

    return DTensor(
        res,
        res_spec,
        requires_grad=res.requires_grad,
    )


# NOTE: As explained below at _nll_loss_and_log_softmax_backward, the
# _log_softmax_backward_handler does not actually do any computation.
def _log_softmax_backward_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    grad_output = cast(DTensor, args[0])
    input_dtype = cast(torch.dtype, args[3])
    return grad_output.to(input_dtype)


# NOTE: The implementation follows torch._decomp.decomposition._nll_loss_forward,
# with customized communication inserted to perform distributed computation.
def _nll_loss_forward(
    x: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    local_weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
    input_shape: torch.Size,
    channel_dim: int,
    mesh: DeviceMesh,
    mesh_dim: int,
) -> Tuple[Tensor, Tensor]:
    n_dims = x.dim()
    channel_dim = 1
    if n_dims < 2:
        channel_dim = 0

    def _weight_view(weight: Tensor) -> Tensor:
        if n_dims > 1:
            shape = [
                1,
            ] * n_dims
            shape[channel_dim] = weight.shape[0]
            w = weight.view(shape)
        else:
            w = weight
        return w

    if weight is not None:
        w = _weight_view(weight)
        assert local_weight is not None
        local_w = _weight_view(local_weight)
        x = x * local_w
    safe_target = torch.where(target != ignore_index, target, 0)
    safe_target_ = safe_target.unsqueeze(channel_dim)

    # The following code block is a distributed version of
    # result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim)
    partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=channel_dim)
    safe_target_partial_ = partial_placement._partition_value(
        safe_target_, mesh, mesh_dim
    )
    result_partial = torch.gather(x, channel_dim, safe_target_partial_)
    # an all_reduce happens here
    result_reduced = partial_placement._reduce_value(result_partial, mesh, mesh_dim)
    result = -result_reduced.squeeze(channel_dim)

    result = torch.where(target != ignore_index, result, 0)

    if reduction == Reduction.NONE.value and n_dims > 1:
        total_weight = x.new_full((), 0.0)
        return result, total_weight

    if weight is not None:
        new_shape = list(x.shape)
        new_shape[channel_dim] = -1
        w = w.expand(new_shape)
        wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim)
        wsum = torch.where(target != ignore_index, wsum, 0)
        total_weight = wsum.sum()
    else:
        total_weight = (target != ignore_index).sum().to(x)

    # NOTE: this is correct only on 1D DeviceMesh; o/w additional
    #       all-reduce on result and total_weight is needed
    if reduction == Reduction.SUM.value:
        result = result.sum()
    elif reduction == Reduction.MEAN.value:
        result = result.sum() / total_weight

    return result, total_weight

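# A minimal single-process sketch of the distributed gather above, with the
# _MaskPartial partition/reduce pair emulated by hand: each "rank" only looks
# up targets that fall into its local slice of classes and contributes zero
# otherwise, so a SUM all_reduce reconstructs the full gather.
#
#     x = torch.randn(4, 16)
#     target = torch.randint(16, (4,))
#     chunks = x.chunk(4, dim=1)  # each "rank" owns 4 of the 16 classes
#     full_gather = torch.zeros(4)
#     for rank, chunk in enumerate(chunks):
#         lo = rank * 4
#         mask = (target >= lo) & (target < lo + 4)
#         local_idx = torch.where(mask, target - lo, 0)
#         local = chunk.gather(1, local_idx.unsqueeze(1)).squeeze(1)
#         full_gather += torch.where(mask, local, 0.0)  # emulated SUM all_reduce
#     torch.testing.assert_close(
#         full_gather, x.gather(1, target.unsqueeze(1)).squeeze(1)
#     )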

def _nll_loss_forward_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    x = cast(DTensor, args[0])
    target = args[1]
    weight = args[2]
    reduction = cast(int, args[3])
    ignore_index = cast(int, args[4])

    channel_dim = 1 if x.dim() >= 2 else 0
    channel_dim_size = x.shape[channel_dim]
    spec = x._spec
    mesh_dim = _find_all_reduce_mesh_dim(spec.placements, channel_dim)

    # Check user input: if target and weight are not DTensors, convert them to DTensors;
    # if they are DTensors, check that they have the desired placements.
    target_placements = _skip_dim(
        replicate_reduction_dims(spec.placements, [channel_dim]), channel_dim
    )
    all_replicate_placements = (Replicate(),) * spec.mesh.ndim
    target = _cast_to_dtensor(target, target_placements, spec.mesh)
    local_weight = None
    if weight is not None:
        weight = _cast_to_dtensor(weight, all_replicate_placements, spec.mesh)
        # For local computation, both (replicated) weight and (sharded) local_weight
        # are needed in _nll_loss_forward(). local_weight is generated here using
        # DTensor API, without incurring any communication.
        sharded_placements = [
            Shard(0) if i == mesh_dim else Replicate() for i in range(spec.mesh.ndim)
        ]
        local_weight = weight.redistribute(spec.mesh, sharded_placements)._local_tensor
        assert local_weight.shape[0] == x._local_tensor.shape[channel_dim]

    if reduction == Reduction.NONE.value:
        output_placements = target_placements
    else:
        output_placements = all_replicate_placements

    # tensor inputs to _propagate_tensor_meta need to be DTensors
    args = list(args)
    args[1], args[2] = target, weight
    output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs)

    result, total_weight = _nll_loss_forward(
        x._local_tensor,
        target._local_tensor,
        weight._local_tensor if weight is not None else None,
        local_weight,
        reduction,
        ignore_index,
        x.shape,
        channel_dim,
        spec.mesh,
        mesh_dim,
    )
    out_spec = DTensorSpec(spec.mesh, output_placements, tensor_meta=output_tensor_meta)

    return (
        DTensor(
            result,
            out_spec,
            requires_grad=result.requires_grad,
        ),
        total_weight,
    )

# NOTE: The backward computation of cross_entropy goes through two steps:
# backward for nll_loss and then backward for log_softmax. In loss parallel,
# the two steps are fused into the following function (called by _nll_loss_backward_handler)
# to avoid communication when the target contains class indices, not class probabilities.
# Also note that the _log_softmax_backward_handler does not perform computation.
# The implementation resembles _nll_loss_backward and _log_softmax_backward_data
# from torch._decomp.decomposition.
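# The fused formula at the end of the function can be read as follows: ``x`` is
# the log_softmax output, so exp(x) is softmax(logits), and grad_input holds -1
# at the target positions owned by this shard, so
# (grad_input + exp(x)) * grad_output == (softmax(logits) - one_hot(target)) * grad_output,
# the standard cross-entropy gradient with respect to the logits (with weight,
# reduction, and ignore_index scaling already folded into grad_output).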
def _nll_loss_and_log_softmax_backward(
    grad_output: Tensor,
    x: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
    total_weight: Tensor,
    input_shape: torch.Size,
    channel_dim: int,
    mesh: DeviceMesh,
    mesh_dim: int,
) -> Tensor:
    channel_dim = 0 if x.dim() < 2 else 1
    if reduction == Reduction.MEAN.value:
        grad_output = grad_output / total_weight

    target = target.unsqueeze(channel_dim)
    safe_target = torch.where(target != ignore_index, target, 0)
    grad_input = torch.zeros_like(x)

    # The following code block is a distributed version of
    # grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0)
    partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=channel_dim)
    safe_target = safe_target.squeeze(channel_dim).flatten()
    masked_safe_target = partial_placement._partition_value(safe_target, mesh, mesh_dim)
    # only update grad_input to -1 if not masked
    assert partial_placement.mask_buffer.data is not None
    grad_update = partial_placement.mask_buffer.data.to(grad_input.dtype) - 1.0
    arange_1d = torch.arange(
        masked_safe_target.shape[0], device=masked_safe_target.device
    )
    # The first two cases with x.dim() <= 2 are for aten.nll_loss_backward.default;
    # the last case is for aten.nll_loss2d_backward.default.
    if x.dim() == 1:
        grad_input[masked_safe_target] = grad_update
    elif x.dim() == 2:
        grad_input[arange_1d, masked_safe_target] = grad_update
    else:
        grad_input_t = grad_input.transpose(channel_dim, -1)
        intermediate_shape = grad_input_t.shape
        grad_input_2d = grad_input_t.reshape(-1, x.shape[channel_dim])
        grad_input_2d[arange_1d, masked_safe_target] = grad_update
        grad_input = grad_input_2d.view(intermediate_shape).transpose(channel_dim, -1)

    if grad_input.dim() > grad_output.dim() > 0:
        grad_output = grad_output.unsqueeze(channel_dim)

    if weight is not None:
        new_shape = [1 for _ in range(x.dim())]
        new_shape[channel_dim] = weight.shape[0]
        weight = weight.reshape(new_shape)
        # In order for fused computation to work, the following line is rewritten.
        # grad_output = grad_output * weight
        new_shape = list(x.shape)
        new_shape[channel_dim] = -1
        w = weight.expand(new_shape)
        w_target = torch.gather(w, channel_dim, target)
        grad_output = grad_output * w_target

    grad_output = torch.where(target != ignore_index, grad_output, 0)

    # NOTE: Instead of directly returning the grad_input as grad_output for log_softmax,
    # here we perform backward computation for log_softmax altogether to avoid the
    # otherwise extra all_gather communication.
    # return grad_input * grad_output
    return (grad_input + torch.exp(x)) * grad_output


def _nll_loss_backward_handler(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    grad_output = cast(DTensor, args[0])
    x = cast(DTensor, args[1])
    target = args[2]
    weight = args[3]
    reduction = cast(int, args[4])
    ignore_index = cast(int, args[5])
    total_weight = cast(Tensor, args[6])

    channel_dim = 1 if x.dim() >= 2 else 0
    spec = x._spec
    mesh_dim = _find_all_reduce_mesh_dim(spec.placements, channel_dim)

    # if target and weight are not DTensors, convert them to DTensors
    target_placements = _skip_dim(
        replicate_reduction_dims(spec.placements, [channel_dim]), channel_dim
    )
    all_replicate_placements = (Replicate(),) * spec.mesh.ndim
    target = _cast_to_dtensor(target, target_placements, spec.mesh)
    if weight is not None:
        weight = _cast_to_dtensor(weight, all_replicate_placements, spec.mesh)

    # tensor inputs to _propagate_tensor_meta need to be DTensors
    args = list(args)
    args[2], args[3] = target, weight
    args[6] = _cast_to_dtensor(total_weight, all_replicate_placements, spec.mesh)
    output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs)

    result = _nll_loss_and_log_softmax_backward(
        grad_output._local_tensor,
        x._local_tensor,
        target._local_tensor,
        weight._local_tensor if weight is not None else None,
        reduction,
        ignore_index,
        total_weight,
        x.shape,
        channel_dim,
        spec.mesh,
        mesh_dim,
    )
    # the output sharding is the same as input sharding: Shard(channel_dim) on mesh_dim
    out_spec = DTensorSpec(
        spec.mesh,
        spec.placements,
        tensor_meta=output_tensor_meta,
    )

    return DTensor(
        result,
        out_spec,
        requires_grad=result.requires_grad,
    )

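# Mapping from the aten ops intercepted during DTensor dispatch to their
# loss-parallel handlers; loss_parallel() installs these on DTensor's op
# dispatcher on entry and removes them on exit.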
customized_loss_ops = {
    aten._log_softmax.default: _log_softmax_handler,
    aten._log_softmax_backward_data.default: _log_softmax_backward_handler,
    aten.nll_loss_forward.default: _nll_loss_forward_handler,
    aten.nll_loss2d_forward.default: _nll_loss_forward_handler,
    aten.nll_loss_backward.default: _nll_loss_backward_handler,
    aten.nll_loss2d_backward.default: _nll_loss_backward_handler,
}


def _enable_custom_loss_ops():
    DTensor._op_dispatcher._custom_op_handlers.update(customized_loss_ops)


def _disable_custom_loss_ops():
    for custom_op in customized_loss_ops:
        DTensor._op_dispatcher._custom_op_handlers.pop(custom_op)