• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) Meta Platforms, Inc. and affiliates
2import dataclasses
3from collections import defaultdict
4from typing import Dict, List, Set, TYPE_CHECKING
5
6from torch.distributed.checkpoint.planner import SavePlan, WriteItem
7
8
9if TYPE_CHECKING:
10    from torch.distributed.checkpoint.metadata import MetadataIndex
11
12__all__ = ["dedup_save_plans"]
13
14
def dedup_save_plans(
    all_plans: List[SavePlan],
    save_to_lowest_rank: bool = False,
) -> List[SavePlan]:
    """
    Remove write items that appear in more than one SavePlan.

    Write items are identified by their ``index``. Every index that occurs in
    several plans is kept by exactly one of those plans and dropped from all
    the others. Items that occur in a single plan are left where they are.

    Args:
        all_plans: The plans to de-duplicate. The list is updated in place:
            every entry is replaced by a copy (``dataclasses.replace``) whose
            ``items`` no longer contains the dropped duplicates.
        save_to_lowest_rank: If ``True``, a duplicated item is kept by the
            plan with the lowest index in ``all_plans`` among the plans that
            contain it. If ``False`` (default), it is kept by whichever of
            those plans has the smallest storage size assigned so far; ties
            are broken in favor of the lowest plan index.

    Returns:
        The same ``all_plans`` list object, after de-duplication.
    """

    # For every write item index, record which plans contain it and keep one
    # representative item (the last one seen) to estimate its storage size.
    write_item_to_plan_indices: Dict[MetadataIndex, Set[int]] = defaultdict(set)
    write_item_idx_to_write_item: Dict[MetadataIndex, WriteItem] = {}
    for plan_idx, plan in enumerate(all_plans):
        for write_item in plan.items:
            write_item_to_plan_indices[write_item.index].add(plan_idx)
            write_item_idx_to_write_item[write_item.index] = write_item

    # Greedily assign each item to one plan, tracking the size assigned to
    # every plan so far, and note which plans must drop the item.
    to_remove: List[Set[MetadataIndex]] = [set() for _ in all_plans]
    plan_to_size = [0] * len(all_plans)
    for write_item_idx, plan_indices in write_item_to_plan_indices.items():
        if save_to_lowest_rank:
            select_plan_idx = min(plan_indices)
        else:
            # Include the plan index in the key so that ties between equally
            # sized plans do not depend on set iteration order.
            select_plan_idx = min(
                plan_indices,
                key=lambda plan_idx: (plan_to_size[plan_idx], plan_idx),
            )

        write_item = write_item_idx_to_write_item[write_item_idx]
        # essentially ignores the storage size of anything that is not a tensor, since
        # we don't know how much storage they represent
        plan_to_size[select_plan_idx] += write_item.tensor_storage_size() or 1

        for plan_idx in plan_indices:
            if plan_idx != select_plan_idx:
                to_remove[plan_idx].add(write_item_idx)

    # Rebuild every plan without the items that were assigned elsewhere.
    for plan_idx, remove_set in enumerate(to_remove):
        new_items = [
            write_item
            for write_item in all_plans[plan_idx].items
            if write_item.index not in remove_set
        ]
        all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items)

    return all_plans
61