# Copyright (c) Meta Platforms, Inc. and affiliates

import torch
import torch.distributed.tensor._ops  # force import all built-in dtensor ops
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh  # noqa: F401
from torch.distributed.tensor._api import (
    distribute_module,
    distribute_tensor,
    DTensor,
    empty,
    full,
    ones,
    rand,
    randn,
    zeros,
)
from torch.distributed.tensor.placement_types import (
    Partial,
    Placement,
    Replicate,
    Shard,
)
from torch.optim.optimizer import (
    _foreach_supported_types as _optim_foreach_supported_types,
)
from torch.utils._foreach_utils import (
    _foreach_supported_types as _util_foreach_supported_types,
)


# All public APIs from the dtensor package
__all__ = [
    "DTensor",
    "distribute_tensor",
    "distribute_module",
    "Shard",
    "Replicate",
    "Partial",
    "Placement",
    "ones",
    "empty",
    "full",
    "rand",
    "randn",
    "zeros",
]


# Register DTensor as a supported type for the foreach implementations used by
# the optimizers and by clip_grad_norm_, so that the fused foreach path is
# preferred over the per-tensor for-loop implementation on CUDA.
if DTensor not in _optim_foreach_supported_types:
    _optim_foreach_supported_types.append(DTensor)

if DTensor not in _util_foreach_supported_types:
    _util_foreach_supported_types.append(DTensor)


# Set the public namespace on names defined in private modules, so they report
# torch.distributed.tensor rather than the private _api module they live in.
DTensor.__module__ = "torch.distributed.tensor"
distribute_tensor.__module__ = "torch.distributed.tensor"
distribute_module.__module__ = "torch.distributed.tensor"
ones.__module__ = "torch.distributed.tensor"
empty.__module__ = "torch.distributed.tensor"
full.__module__ = "torch.distributed.tensor"
rand.__module__ = "torch.distributed.tensor"
randn.__module__ = "torch.distributed.tensor"
zeros.__module__ = "torch.distributed.tensor"
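

# -----------------------------------------------------------------------------
# Illustrative sketch (kept as a comment so nothing runs at import time): how
# the public API re-exported above is typically used. The "cuda" device type,
# the 4-rank world size, and the tensor shape are assumptions for this example,
# not requirements of the API.
#
#   import torch
#   from torch.distributed.tensor import distribute_tensor, init_device_mesh, Shard
#
#   mesh = init_device_mesh("cuda", (4,))        # 1-D mesh over 4 ranks
#   global_tensor = torch.randn(8, 16)           # full (logical) tensor
#   dtensor = distribute_tensor(global_tensor, mesh, placements=[Shard(0)])
#   # Each rank now holds an 8/4 = 2-row local shard of the global tensor:
#   local = dtensor.to_local()                   # local shape: (2, 16)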
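#
# Why the foreach registration above matters, as a hedged sketch: once DTensor
# is in the _foreach_supported_types lists, optimizers constructed with the
# default foreach=None select the fused foreach implementation for DTensor
# parameters on CUDA instead of the per-parameter for-loop. The module and
# mesh names below are assumptions for illustration.
#
#   from torch.distributed.tensor import distribute_module
#
#   sharded_module = distribute_module(my_module, mesh)   # my_module: an nn.Module
#   opt = torch.optim.AdamW(sharded_module.parameters())  # foreach=None (default)
#   # loss.backward(); opt.step() now dispatches foreach ops on DTensor params.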