# mypy: allow-untyped-defs

# Copyright (c) Meta Platforms, Inc. and affiliates

import os
import shutil
import tempfile
from functools import wraps
from typing import Any, Callable, Dict, Optional, Tuple

import torch.distributed as dist


def with_temp_dir(
    func: Optional[Callable] = None,
) -> Optional[Callable]:
    """
    Wrapper to initialize temp directory for distributed checkpoint.

    The decorated callable is expected to be a method: the wrapper stores the
    directory path on ``self.temp_dir`` before invoking it and removes the
    directory afterwards, whether or not the method raised.

    When a process group is initialized, only rank 0 creates the directory and
    its path is broadcast so every rank sees the same ``self.temp_dir``.
    Without a process group, the calling process creates its own directory.

    Args:
        func: The method to wrap. Must not be ``None``.

    Returns:
        The wrapping function. It forwards all positional and keyword
        arguments to ``func`` and always returns ``None`` (the result of
        ``func`` is discarded).
    """
    assert func is not None

    @wraps(func)
    def wrapper(self, *args: Any, **kwargs: Any) -> None:
        if dist.is_initialized():
            # Only create temp_dir when rank is 0; the other ranks contribute
            # a placeholder that the broadcast below overwrites.
            if dist.get_rank() == 0:
                temp_dir = tempfile.mkdtemp()
                print(f"Using temp directory: {temp_dir}")
            else:
                temp_dir = ""
            object_list = [temp_dir]

            # Broadcast temp_dir to all the other ranks
            os.sync()
            try:
                dist.broadcast_object_list(object_list)
            except BaseException:
                # The try/finally below has not been entered yet, so remove
                # the directory this rank just created before propagating.
                if temp_dir:
                    shutil.rmtree(temp_dir, ignore_errors=True)
                raise
            self.temp_dir = object_list[0]
            os.sync()
        else:
            temp_dir = tempfile.mkdtemp()
            print(f"No process group initialized, using temp directory: {temp_dir}")
            self.temp_dir = temp_dir

        try:
            func(self, *args, **kwargs)
        finally:
            # Every caller attempts the removal (rank 0, other ranks, and the
            # non-distributed case alike); ignore_errors=True keeps a missing
            # or already-removed directory from raising.
            shutil.rmtree(self.temp_dir, ignore_errors=True)

    return wrapper