• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) Meta Platforms, Inc. and affiliates
2
3import torch
4import torch.distributed.tensor._ops  # force import all built-in dtensor ops
5from torch.distributed.device_mesh import DeviceMesh, init_device_mesh  # noqa: F401
6from torch.distributed.tensor._api import (
7    distribute_module,
8    distribute_tensor,
9    DTensor,
10    empty,
11    full,
12    ones,
13    rand,
14    randn,
15    zeros,
16)
17from torch.distributed.tensor.placement_types import (
18    Partial,
19    Placement,
20    Replicate,
21    Shard,
22)
23from torch.optim.optimizer import (
24    _foreach_supported_types as _optim_foreach_supported_types,
25)
26from torch.utils._foreach_utils import (
27    _foreach_supported_types as _util_foreach_supported_types,
28)
29
30
# All public APIs re-exported from the private ``_api``/placement modules.
# NOTE(review): ``DeviceMesh`` / ``init_device_mesh`` are imported above
# (``noqa: F401``) for attribute access convenience but are deliberately not
# listed here, so ``from ... import *`` does not pick them up.
__all__ = [
    # core DTensor class and the high-level distribution entry points
    "DTensor",
    "distribute_tensor",
    "distribute_module",
    # placement types describing how a tensor is laid out across a device mesh
    "Shard",
    "Replicate",
    "Partial",
    "Placement",
    # DTensor factory functions mirroring the ``torch.*`` tensor factories
    "ones",
    "empty",
    "full",
    "rand",
    "randn",
    "zeros",
]
47
48
49# Append DTensor to the list of supported types for foreach implementation for optimizer
50# and clip_grad_norm_ so that we will try to use foreach over the for-loop implementation on CUDA.
51if DTensor not in _optim_foreach_supported_types:
52    _optim_foreach_supported_types.append(DTensor)
53
54if DTensor not in _util_foreach_supported_types:
55    _util_foreach_supported_types.append(DTensor)
56
57
58# Set namespace for exposed private names
59DTensor.__module__ = "torch.distributed.tensor"
60distribute_tensor.__module__ = "torch.distributed.tensor"
61distribute_module.__module__ = "torch.distributed.tensor"
62ones.__module__ = "torch.distributed.tensor"
63empty.__module__ = "torch.distributed.tensor"
64full.__module__ = "torch.distributed.tensor"
65rand.__module__ = "torch.distributed.tensor"
66randn.__module__ = "torch.distributed.tensor"
67zeros.__module__ = "torch.distributed.tensor"
68