• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) Meta Platforms, Inc. and affiliates
2from typing import Dict, Tuple
3
4from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
5
6from . import _version
7from ._traverse import (
8    OBJ_PATH,
9    set_element,
10    STATE_DICT_ITEM,
11    traverse_state_dict,
12    traverse_state_dict_v_2_3,
13)
14
15
# TODO:
# Need to add ability to handle tuple, OrderedDict, NamedTuple.
# Update mappings from dict to a class.
# Change set_element to recreate the right type for tuple, OrderedDict, and NamedTuple.


# Maps each dot-joined key produced by ``flatten_state_dict`` to the object
# path it was derived from; ``unflatten_state_dict`` uses it to place every
# value back at its original location.
FLATTEN_MAPPING = Dict[str, OBJ_PATH]
25
26
27# TODO: Update Docstring for nested_dict.py
def flatten_state_dict(
    state_dict: STATE_DICT_TYPE,
) -> Tuple[STATE_DICT_TYPE, FLATTEN_MAPPING]:
    """
    Flatten ``state_dict`` made of nested dicts and lists into a top level dictionary.

    Use ``unflatten_state_dict`` to revert this process.

    Args:
        state_dict: The (possibly nested) state dict to flatten.

    Returns:
        A tuple ``(flattened, mappings)``. ``flattened`` maps each new flat key
        to the value the traversal reported for that path, and ``mappings``
        maps the same flat key back to the original object path.

    Raises:
        ValueError: If two different object paths produce the same flat key.
            Examples are a top-level key ``"a.b"`` next to a nested
            ``{"a": {"b": ...}}``, or path components ``1`` and ``"1"`` at
            the same position, since every component is passed through ``str``.

    N.B. The new keys are derived from the object paths, joined by dot.
        For example: ``{ 'a': {'b':...}}`` results in the key `a.b`.
    """
    flattened: STATE_DICT_TYPE = {}
    mappings: FLATTEN_MAPPING = {}

    def flat_copy(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
        # Record one visited value under its dot-joined key, refusing collisions
        # so that ``unflatten_state_dict`` can always invert the result.
        new_fqn = ".".join(map(str, path))
        if new_fqn in flattened:
            raise ValueError(f"duplicated flatten key {new_fqn}")
        flattened[new_fqn] = value
        mappings[new_fqn] = path

    # We started to flatten dictionary since v2.4. But in order to not break
    # the checkpoints that were saved before v2.4, we need to keep the old
    # traversal so that we can reconstruct those checkpoints.
    # A ``None`` derived version never equals "2_3", so no separate None check
    # is needed.
    use_v_2_3 = _version._derived_version == "2_3"
    if use_v_2_3:
        traverse_state_dict_v_2_3(state_dict, flat_copy)
    else:
        traverse_state_dict(state_dict, flat_copy)
    return flattened, mappings
61
62
def unflatten_state_dict(
    state_dict: STATE_DICT_TYPE, mapping: FLATTEN_MAPPING
) -> STATE_DICT_TYPE:
    """
    Rebuild the nested state_dict from a flattened one and its key mapping.

    Args:
        state_dict: Flat dictionary whose keys are flattened keys.
        mapping: For every flattened key, the object path at which its value
            is placed in the rebuilt dictionary.

    Returns:
        A new nested dictionary produced by calling ``set_element`` for each
        value at its mapped object path.

    Raises:
        KeyError: If a key of ``state_dict`` has no entry in ``mapping``.
    """
    restored: STATE_DICT_TYPE = {}
    for flat_key in state_dict:
        obj_path = mapping[flat_key]
        set_element(restored, obj_path, state_dict[flat_key])
    return restored
71