• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2
3# Copyright (c) Meta Platforms, Inc. and affiliates
4
5import os
6import shutil
7import tempfile
8from functools import wraps
9from typing import Any, Callable, Dict, Optional, Tuple
10
11import torch.distributed as dist
12
13
def with_temp_dir(
    func: Optional[Callable] = None,
) -> Optional[Callable]:
    """
    Decorate a method so that it runs with a temporary directory available
    as ``self.temp_dir``.

    When a default process group is initialized, only rank 0 creates the
    directory (via ``tempfile.mkdtemp()``); its path is then broadcast so that
    every rank stores the same path in ``self.temp_dir``. When no process
    group is initialized, the calling process creates its own directory.

    The directory is removed on a best-effort basis once ``func`` returns or
    raises.

    Args:
        func: The method to wrap. It is invoked as
            ``func(self, *args, **kwargs)`` and must not be ``None``.

    Returns:
        The wrapping function. It always returns ``None``; whatever ``func``
        returns is discarded.
    """
    # Also narrows Optional[Callable] to Callable for the closure below.
    assert func is not None

    @wraps(func)
    def wrapper(self, *args: Any, **kwargs: Any) -> None:
        if dist.is_initialized():
            # Only rank 0 creates the directory; the other ranks contribute a
            # placeholder that the broadcast below overwrites.
            if dist.get_rank() == 0:
                temp_dir = tempfile.mkdtemp()
                print(f"Using temp directory: {temp_dir}")
            else:
                temp_dir = ""
            object_list = [temp_dir]

            # Broadcast rank 0's path to all the other ranks. os.sync() is
            # called on both sides of the broadcast, presumably so the new
            # directory is flushed before other ranks start using the path.
            os.sync()
            dist.broadcast_object_list(object_list)
            self.temp_dir = object_list[0]
            os.sync()
        else:
            temp_dir = tempfile.mkdtemp()
            print(f"No process group initialized, using temp directory: {temp_dir}")
            self.temp_dir = temp_dir

        try:
            func(self, *args, **kwargs)
        finally:
            # Every process attempts the removal, with or without a process
            # group. ignore_errors=True keeps cleanup best-effort, so a
            # directory that is already gone does not raise here.
            shutil.rmtree(self.temp_dir, ignore_errors=True)

    return wrapper
52