# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Dict, Tuple

from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE

from . import _version
from ._traverse import (
    OBJ_PATH,
    set_element,
    STATE_DICT_ITEM,
    traverse_state_dict,
    traverse_state_dict_v_2_3,
)


"""
TODO:
Need to add ability to handle tuple, OrderedDict, NamedTuple.
Update mappings from dict to a class.
Change set_element to recreate the right type for tuple, OrderedDict, and NamedTuple.
"""


# Maps each flattened (dot-joined) key to the object path it was built from.
FLATTEN_MAPPING = Dict[str, OBJ_PATH]


# TODO: Update Docstring for nested_dict.py
def flatten_state_dict(
    state_dict: STATE_DICT_TYPE,
) -> Tuple[STATE_DICT_TYPE, FLATTEN_MAPPING]:
    """
    Collapse a ``state_dict`` of nested dicts and lists into a single-level dict.

    The inverse operation is ``unflatten_state_dict``.

    Args:
        state_dict: The (possibly nested) state dict to flatten.

    Returns:
        A ``(flat_state_dict, mapping)`` tuple. Every key of ``flat_state_dict``
        is the visited object path with its components converted to ``str``
        and joined by ``"."``; ``mapping`` associates that same key with the
        object path itself.
        For example: ``{ 'a': {'b':...}}`` results in the key `a.b`.

    Raises:
        ValueError: If two visited paths produce the same flattened key.
    """
    flat_items: STATE_DICT_TYPE = {}
    key_to_path: FLATTEN_MAPPING = {}

    def _record(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
        # Visitor handed to the traversal: store one leaf under its dotted key.
        new_fqn = ".".join(str(part) for part in path)
        if new_fqn in flat_items:
            raise ValueError(f"duplicated flatten key {new_fqn}")
        flat_items[new_fqn] = value
        key_to_path[new_fqn] = path

    # Dictionaries have been flattened since v2.4. Checkpoints written before
    # v2.4 must still be reconstructable, so the legacy traversal is kept and
    # selected whenever the derived version says "2_3". A derived version of
    # ``None`` never equals "2_3", so it falls through to the current traversal.
    traverse = (
        traverse_state_dict_v_2_3
        if _version._derived_version == "2_3"
        else traverse_state_dict
    )
    traverse(state_dict, _record)
    return flat_items, key_to_path


def unflatten_state_dict(
    state_dict: STATE_DICT_TYPE, mapping: FLATTEN_MAPPING
) -> STATE_DICT_TYPE:
    """
    Rebuild a nested state_dict from a flattened ``state_dict`` and its ``mapping``.

    Args:
        state_dict: Flattened state dict, keyed by the flattened keys.
        mapping: Object path for each flattened key; it is indexed with every
            key of ``state_dict``, so each of those keys must be present.

    Returns:
        A new dict in which each value has been placed, via ``set_element``,
        at the object path that ``mapping`` records for its key.
    """
    restored: STATE_DICT_TYPE = {}
    for flat_key, item in state_dict.items():
        obj_path = mapping[flat_key]
        set_element(restored, obj_path, item)
    return restored