# Copyright (c) Meta Platforms, Inc. and affiliates
import dataclasses
from collections import defaultdict
from typing import Dict, List, Set, TYPE_CHECKING

from torch.distributed.checkpoint.planner import SavePlan, WriteItem


if TYPE_CHECKING:
    from torch.distributed.checkpoint.metadata import MetadataIndex

__all__ = ["dedup_save_plans"]


def dedup_save_plans(
    all_plans: List[SavePlan],
    save_to_lowest_rank: bool = False,
) -> List[SavePlan]:
    """
    Removes duplicate entries from appearing on multiple SavePlans.

    Write items are identified by their ``index``. When the same index appears
    in more than one plan, exactly one plan keeps the entry and it is dropped
    from every other plan:

    * ``save_to_lowest_rank=False`` (default): the entry goes to the plan with
      the smallest storage assigned so far by this function; ties go to the
      plan with the lowest position in ``all_plans``. Indices are visited in
      first-seen order (plans in list order, items in plan order), so this is
      a greedy balance, not an optimal one. Each item is charged
      ``tensor_storage_size()``, or 1 when that is falsy (e.g. non-tensor
      items, whose storage size is unknown).
    * ``save_to_lowest_rank=True``: the entry goes to the plan with the lowest
      position in ``all_plans`` among those containing it.

    Args:
        all_plans: The plans to deduplicate. A plan's position in this list is
            what ``save_to_lowest_rank`` treats as its rank.
        save_to_lowest_rank: Choose the owner of a duplicated entry by lowest
            position instead of by balancing planned storage.

    Returns:
        ``all_plans`` itself. Every element is replaced in place by a
        ``dataclasses.replace`` copy whose ``items`` exclude the entries
        assigned to another plan; the relative order of kept items is
        preserved.
    """

    # index -> positions of the plans that contain it (insertion-ordered by
    # first appearance), plus one representative WriteItem per index (the
    # last one seen) used for size estimation.
    write_item_to_plan_indices: Dict["MetadataIndex", Set[int]] = defaultdict(set)
    write_item_idx_to_write_item: Dict["MetadataIndex", WriteItem] = {}
    for plan_idx, plan in enumerate(all_plans):
        for write_item in plan.items:
            write_item_to_plan_indices[write_item.index].add(plan_idx)
            write_item_idx_to_write_item[write_item.index] = write_item

    # Put each item in exactly one plan and schedule it for removal from the
    # rest.
    to_remove: List[Set["MetadataIndex"]] = [set() for _ in range(len(all_plans))]
    plan_to_size = [0] * len(all_plans)
    for write_item_idx, plan_indices in write_item_to_plan_indices.items():
        if save_to_lowest_rank:
            select_plan_idx = min(plan_indices)
        else:
            # Key on (size, position): with size alone, ``min`` returns the
            # first tied element in *set iteration order*, which follows the
            # hash-table layout rather than plan order (in CPython {3, 9}
            # iterates as 9, 3).
            select_plan_idx = min(
                plan_indices, key=lambda idx: (plan_to_size[idx], idx)
            )
            write_item = write_item_idx_to_write_item[write_item_idx]
            # essentially ignores the storage size of anything that is not a
            # tensor, since we don't know how much storage they represent
            plan_to_size[select_plan_idx] += write_item.tensor_storage_size() or 1

        plan_indices.remove(select_plan_idx)
        for plan_idx in plan_indices:
            to_remove[plan_idx].add(write_item_idx)

    for plan_idx, remove_set in enumerate(to_remove):
        new_items = [
            write_item
            for write_item in all_plans[plan_idx].items
            if write_item.index not in remove_set
        ]
        all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items)

    return all_plans