• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) Meta Platforms, Inc. and affiliates
2import dataclasses
3import logging
4from typing import Dict, List, TYPE_CHECKING
5
6from torch.distributed.checkpoint.planner import SavePlan
7
8
9if TYPE_CHECKING:
10    from torch.distributed.checkpoint.metadata import MetadataIndex
11
12__all__ = ["dedup_tensors"]
13
14
def init_logger() -> logging.Logger:
    """Return this module's logger, configured with a single console handler.

    The logger is set to ``INFO``, given one ``logging.StreamHandler`` (also at
    ``INFO``) with a timestamp/file/process/thread format, and detached from
    its ancestors via ``propagate = False`` so records are not emitted a second
    time by handlers installed higher up the logger hierarchy.

    The function is idempotent: ``logging.getLogger`` hands back the same
    object for a given name, so calling this again (or re-executing the module)
    re-applies the level and ``propagate`` settings but does not attach a
    second console handler.

    Returns:
        The configured ``logging.Logger`` named after this module.
    """
    logger = logging.getLogger(__name__)
    level = logging.INFO
    logger.setLevel(level)

    # Tag our handler with a name so a repeated call can recognise it. Without
    # this guard each call stacks one more StreamHandler on the same logger and
    # every record is printed once per call made so far.
    handler_name = __name__ + ".console"
    already_attached = any(
        handler.get_name() == handler_name for handler in logger.handlers
    )
    if not already_attached:
        console = logging.StreamHandler()
        console.set_name(handler_name)
        formatter = logging.Formatter(
            "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s"
        )
        console.setFormatter(formatter)
        console.setLevel(level)
        logger.addHandler(console)

    logger.propagate = False
    return logger
28
29
# Module-level logger used by dedup_tensors; created once, when this module
# is first imported, by calling init_logger().
logger = init_logger()
31
32
def dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]:
    """Remove write items that are replicated across several save plans.

    Plans are scanned in list order. For every write-item index that shows up
    in more than one plan, the first plan that lists it keeps it and every
    later plan has all items with that index removed. The relative order of the
    surviving items inside each plan is preserved.

    An index that is repeated *inside a single plan* is not treated as
    cross-plan replication: the plan that owns the index keeps all of its own
    copies, so the item can never end up removed from every plan.

    Neither the input list nor the plan objects are mutated. Plans that need
    changes are rebuilt with ``dataclasses.replace``; plans with nothing to
    remove are returned as the very same objects.

    Args:
        all_plans: The save plans to deduplicate, one per position in the list.

    Returns:
        A new list, parallel to ``all_plans``, holding the deduplicated plans.
    """
    all_plans = list(all_plans)

    # For each index, the positions of the plans that contain it, in first-seen
    # order and with each plan recorded at most once.
    key_to_plan: Dict[MetadataIndex, List[int]] = {}
    for plan_idx, plan in enumerate(all_plans):
        for write_item in plan.items:
            owners = key_to_plan.setdefault(write_item.index, [])
            # Plans are visited in increasing order, so a repeat coming from the
            # plan currently being scanned can only match the last entry.
            # Recording it twice would make ``owners[1:]`` below point back at
            # the plan that is supposed to keep the item.
            if not owners or owners[-1] != plan_idx:
                owners.append(plan_idx)

    # Remove duplicates by always keeping the first entry.
    # Compute the per-plan remove list (empty slice when an index has one owner).
    plan_to_keys: Dict[int, List[MetadataIndex]] = {}
    for key, owners in key_to_plan.items():
        for plan_idx in owners[1:]:
            plan_to_keys.setdefault(plan_idx, []).append(key)

    if plan_to_keys:
        logger.info("Duplicate keys to remove: %s", plan_to_keys)

    for plan_idx, keys in plan_to_keys.items():
        key_set = set(keys)
        # Rebuild the item list without the indices another plan already owns.
        new_items = [
            write_item
            for write_item in all_plans[plan_idx].items
            if write_item.index not in key_set
        ]
        all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items)

    return all_plans
63