# Copyright (c) Meta Platforms, Inc. and affiliates
import dataclasses
import logging
from typing import Dict, List, TYPE_CHECKING

from torch.distributed.checkpoint.planner import SavePlan


if TYPE_CHECKING:
    from torch.distributed.checkpoint.metadata import MetadataIndex

__all__ = ["dedup_tensors"]


def init_logger() -> logging.Logger:
    logger = logging.getLogger(__name__)
    level = logging.INFO
    logger.setLevel(level)
    console = logging.StreamHandler()
    formatter = logging.Formatter(
        "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s"
    )
    console.setFormatter(formatter)
    console.setLevel(level)
    logger.addHandler(console)
    logger.propagate = False
    return logger


logger = init_logger()


def dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]:
    """Remove duplicate writes from a list of SavePlans.

    When several plans contain a WriteItem for the same MetadataIndex,
    the first plan keeps the item and every later plan has it removed,
    so each piece of data is written exactly once. Returns the updated
    list of plans.
    """
    all_plans = list(all_plans)
    key_to_plan: Dict[MetadataIndex, List[int]] = {}
    for plan_idx, plan in enumerate(all_plans):
        for write_item in plan.items:
            key_to_plan.setdefault(write_item.index, []).append(plan_idx)

    replicated_items = {k: v for k, v in key_to_plan.items() if len(v) > 1}

    # Remove duplicates by always keeping the first entry.
    # Compute the per-rank remove set.
    plan_to_keys: Dict[int, List[MetadataIndex]] = {}
    for key, plans in replicated_items.items():
        for plan_idx in plans[1:]:
            plan_to_keys.setdefault(plan_idx, []).append(key)
    if len(plan_to_keys) > 0:
        logger.info("Duplicate keys to remove: %s", plan_to_keys)

    for plan_idx, keys in plan_to_keys.items():
        key_set = set(keys)
        # Rewrite the plan's item list, dropping the duplicated entries.
        new_items = [
            write_item
            for write_item in all_plans[plan_idx].items
            if write_item.index not in key_set
        ]
        all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items)

    return all_plans
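

# The demo below is an illustrative sketch, not part of the module's public
# surface: it shows how dedup_tensors collapses overlapping plans. Building
# WriteItems by hand (here as BYTE_IO entries with no tensor_data) is an
# assumption made to keep the example self-contained; in practice plans are
# produced by a SavePlanner during a distributed checkpoint save.
if __name__ == "__main__":
    from torch.distributed.checkpoint.metadata import MetadataIndex
    from torch.distributed.checkpoint.planner import WriteItem, WriteItemType

    # Two ranks both plan to write "optimizer_state"; the second rank also
    # plans a write of its own.
    shared = WriteItem(
        index=MetadataIndex("optimizer_state"), type=WriteItemType.BYTE_IO
    )
    unique = WriteItem(index=MetadataIndex("rank1_only"), type=WriteItemType.BYTE_IO)
    plans = [SavePlan(items=[shared]), SavePlan(items=[shared, unique])]

    deduped = dedup_tensors(plans)
    # The first plan keeps the shared item; the second keeps only its own.
    assert [item.index.fqn for item in deduped[0].items] == ["optimizer_state"]
    assert [item.index.fqn for item in deduped[1].items] == ["rank1_only"]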