# mypy: allow-untyped-defs
import copy
import io
import math
import weakref
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    List,
    Mapping,
    MutableMapping,
    NamedTuple,
    Optional,
    Tuple,
    TYPE_CHECKING,
    Union,
)

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed._functional_collectives import AsyncCollectiveTensor


if dist.is_available() or TYPE_CHECKING:
    from torch.distributed import distributed_c10d
    from torch.distributed._shard.sharded_tensor import ShardedTensor
    from torch.distributed.tensor import distribute_tensor, DTensor, Replicate
    from torch.distributed.tensor._utils import compute_local_shape_and_global_offset


def _identity_func(
    obj: torch.Tensor,
    pg: Optional[dist.ProcessGroup],
    device: Optional[torch.device],
    companion_obj: Any,
) -> torch.Tensor:
    return obj


def _all_gather_sharded_tensor(
    sharded_tensor: "ShardedTensor",
    pg: Optional[dist.ProcessGroup] = None,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    if pg is None:
        pg = distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    shards = sharded_tensor.local_shards()
    dim_0_size = sharded_tensor.size()[0]  # type: ignore[index]
    tensor_numel = sharded_tensor.size().numel()  # type: ignore[union-attr]
    # Every rank contributes the same number of elements: the largest shard
    # (ceil(dim_0 / world_size) rows) times the number of elements per row.
    chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size
    pg_device = (
        distributed_c10d._get_pg_default_device(pg) if device is None else device
    )
    if shards:
        local_tensor = shards[0].tensor.flatten()
        if local_tensor.device.type != pg_device.type:
            local_tensor = local_tensor.to(pg_device)
        num_padding = chunk_size - local_tensor.numel()
        if num_padding > 0:
            local_tensor = F.pad(local_tensor, [0, num_padding])
    else:
        local_tensor = torch.zeros(
            chunk_size, dtype=sharded_tensor.dtype, device=pg_device
        )

    tensor = torch.empty(
        chunk_size * world_size,
        dtype=local_tensor.dtype,
        device=pg_device,
    )
    dist.all_gather_into_tensor(tensor, local_tensor, group=pg)

    # Drop the trailing padding and restore the original global shape.
    tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size())
    return tensor
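

# Illustrative sketch (not part of the original module): how the even-chunk padding
# above lines up for a 5x3 tensor split row-wise across two ranks. This runs in a
# single process and issues no collectives; `_example_chunk_padding` is a
# hypothetical name used only for demonstration.
def _example_chunk_padding() -> None:
    full = torch.arange(15, dtype=torch.float32).reshape(5, 3)
    world_size = 2
    dim_0_size, tensor_numel = full.size(0), full.numel()
    # chunk_size = ceil(5 / 2) * 15 // 5 = 3 * 3 = 9 elements per rank
    chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size
    shard0 = full[:3].flatten()  # 9 elements, no padding needed
    shard1 = full[3:].flatten()  # 6 elements, padded with 3 zeros
    shard1 = F.pad(shard1, [0, chunk_size - shard1.numel()])
    gathered = torch.cat([shard0, shard1])  # what all_gather_into_tensor produces
    restored = gathered.narrow(0, 0, tensor_numel).reshape(full.size())
    assert torch.equal(restored, full)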


class CompanionMismatch(Exception):
    ...


def _iterate_state_dict(
    iter_object: Any,
    sharded_tensor_func: Callable,
    dtensor_func: Callable,
    tensor_func: Callable,
    *,
    pg: Optional[dist.ProcessGroup] = None,
    device: Optional[torch.device] = None,
    cpu_offload: bool = False,
    companion_obj: Any = None,
    ranks_only: Tuple[int, ...] = (),
    type_check: bool = True,
    non_blocking: bool = True,
) -> Dict[str, Any]:
    """Iterate through the state dict, applying the given functions to each tensor type.

    Args:
        iter_object (Any): the target state_dict.
        sharded_tensor_func (Callable): the function to apply to ShardedTensor.
        dtensor_func (Callable): the function to apply to DTensor.
        tensor_func (Callable): the function to apply to Tensor.
        pg (Optional[dist.ProcessGroup]): process group passed to the tensor functions.
        device (Optional[torch.device]): device passed to the tensor functions.
        cpu_offload (bool): whether to offload the tensors to CPU memory. This option
            is ignored if a companion_obj is supplied.
        companion_obj (Any): a companion object to the state dict. If this object
            is supplied, we attempt to copy the tensor to the companion object.
        ranks_only (Tuple[int, ...]): if this tuple is empty, all ranks will
            have the same state_dicts. Otherwise, only ranks in ``ranks_only``
            have the state_dicts; other ranks get empty state_dicts.
        type_check (bool): check if the instance data type is a supported type
            that can be saved by DCP. The currently supported data types are
            torch.Tensor, DTensor, int, float, str, list, dict, None.
        non_blocking (bool): whether to use non-blocking copy when copying to the
            companion object.
    """
    # TODO: should we use pytree?
    cpu_device = torch.device("cpu")
    if isinstance(iter_object, ShardedTensor):
        ret = sharded_tensor_func(iter_object, pg, device, companion_obj)
    elif isinstance(iter_object, DTensor):
        ret = dtensor_func(iter_object, pg, device, companion_obj)
    elif isinstance(iter_object, torch.Tensor):
        ret = tensor_func(iter_object, pg, device, companion_obj)
    elif (
        isinstance(iter_object, (int, float, str, bytes, io.BytesIO))
        or iter_object is None
    ):
        ret = iter_object
    elif isinstance(iter_object, dict):
        if companion_obj is not None and (
            not isinstance(companion_obj, dict)
            or set(companion_obj.keys()) != set(iter_object.keys())
        ):
            # Only report the key sets when the companion is also a dict;
            # otherwise the mismatch is the container type itself.
            msg = (
                f"{set(companion_obj.keys())=} {set(iter_object.keys())=}"
                if isinstance(companion_obj, dict)
                else ""
            )
            raise CompanionMismatch(msg)

        ret = {
            key: _iterate_state_dict(
                value,
                sharded_tensor_func,
                dtensor_func,
                tensor_func,
                pg=pg,
                device=device,
                cpu_offload=cpu_offload,
                companion_obj=companion_obj[key] if companion_obj is not None else None,
                ranks_only=ranks_only,
                type_check=type_check,
                non_blocking=non_blocking,
            )
            for key, value in iter_object.items()
        }
    elif isinstance(iter_object, (list, tuple)):
        if companion_obj is not None and (
            not isinstance(companion_obj, (list, tuple))
            or len(companion_obj) != len(iter_object)
        ):
            raise CompanionMismatch

        ret = [
            _iterate_state_dict(
                v,
                sharded_tensor_func,
                dtensor_func,
                tensor_func,
                pg=pg,
                device=device,
                cpu_offload=cpu_offload,
                companion_obj=companion_obj[idx] if companion_obj is not None else None,
                ranks_only=ranks_only,
                type_check=type_check,
                non_blocking=non_blocking,
            )
            for idx, v in enumerate(iter_object)
        ]
        if isinstance(iter_object, tuple):
            ret = tuple(ret)
    elif not type_check:
        ret = copy.deepcopy(iter_object)
    else:
        raise ValueError(f"Unexpected value type {type(iter_object)}")

    if not ranks_only or dist.get_rank(pg) in ranks_only:
        if isinstance(ret, torch.Tensor):
            if cpu_offload and companion_obj is None:
                ret = ret.to(cpu_device)

            if companion_obj is not None:
                # TODO: support DTensor
                companion_obj.copy_(ret, non_blocking=non_blocking)
                ret = companion_obj
    else:
        ret = {} if isinstance(ret, dict) else None

    return ret
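

# A minimal sketch of how `_iterate_state_dict` walks nested containers, assuming a
# torch build with torch.distributed available (the ShardedTensor/DTensor isinstance
# checks need it). `_example_iterate_state_dict` is a hypothetical name and is not
# called anywhere in this module.
def _example_iterate_state_dict() -> None:
    state_dict = {
        "weight": torch.ones(2, 2),
        "step": 10,
        "nested": {"bias": torch.zeros(3)},
        "history": [torch.full((2,), 2.0), "metadata"],
    }

    def double(obj, pg, device, companion_obj):
        return obj * 2

    out = _iterate_state_dict(
        state_dict,
        _identity_func,  # ShardedTensor handler (unused here)
        _identity_func,  # DTensor handler (unused here)
        double,          # applied to every plain torch.Tensor leaf
        pg=None,
        device=None,
    )
    assert torch.equal(out["weight"], torch.full((2, 2), 2.0))
    assert out["nested"]["bias"].sum().item() == 0
    assert out["history"][1] == "metadata"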


def _gather_state_dict(
    state_dict: Dict[str, Any],
    *,
    pg: Optional[dist.ProcessGroup] = None,
    device: Optional[torch.device] = None,
    cpu_offload: bool = False,
    ranks_only: Tuple[int, ...] = (),
    type_check: bool = True,
) -> Dict[str, Any]:
    """
    Given a state_dict, this API gathers all the ShardedTensors or DTensors in
    the state_dict.

    Args:
        state_dict (Dict[str, Any]): the target sharded state_dict.
        pg (Optional[dist.ProcessGroup]): the process group that is used to
            gather ShardedTensors. Note that gathering a DTensor uses the
            DeviceMesh, so this argument is ignored when gathering a DTensor.
        device (Optional[torch.device]): the device that is used to
            perform the allgather for ShardedTensors. Note that gathering a
            DTensor uses the DeviceMesh, so this argument is ignored when
            gathering a DTensor.
        cpu_offload (bool): whether to offload the tensors to CPU memory. The
            default value is False.
        ranks_only (Tuple[int, ...]): if this tuple is empty, all ranks will
            have the same state_dicts. Otherwise, only ranks in ``ranks_only``
            have the state_dicts; other ranks get empty state_dicts.
        type_check (bool): check if the instance data type is a supported type
            that can be saved by DCP. The currently supported data types are
            torch.Tensor, DTensor, int, float, str, list, dict, None.

    Returns:
        The gathered state dictionary.
    """

    def sharded_tensor_func(value, pg, device, companion_obj):
        # ShardedTensor does not seem to record the original device type.
        # So if the tensor is moved to CPU, we won't know the original type.
        # As a result, we have to rely on the user to tell us the correct one.
        cpu_device = torch.device("cpu")
        output_tensor = _all_gather_sharded_tensor(value, pg, device)
        local_shard_device = (
            value.local_shards()[0].tensor.device
            if value.local_shards()
            else cpu_device
        )
        if output_tensor.device != local_shard_device:
            value = output_tensor.to(local_shard_device)
        else:
            value = output_tensor
        return value

    def dtensor_func(value, pg, device, companion_obj):
        # `device_mesh.device_type` is a string, so compare against the
        # device's type rather than the torch.device object itself.
        if value.device.type != value.device_mesh.device_type:
            value = value.to(value.device_mesh.device_type)
        # FSDP all_gather: [Shard(0)] -> [Replicate()]
        # HSDP all_gather: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
        # 2D FSDP + TP all_gather:
        # - [Shard(0), Shard(n)] -> [Replicate(), Replicate()]
        # - [Shard(0), Replicate()] -> [Replicate(), Replicate()]
        placements = [Replicate() for _ in value.placements]
        value = value.redistribute(
            device_mesh=value.device_mesh,
            placements=placements,
        )
        # Call `wait()` to force the tensor to be synchronous with respect
        # to the main stream.
        # See the discussion in https://github.com/pytorch/pytorch/pull/117799.
        value = value.to_local()
        if isinstance(value, AsyncCollectiveTensor):
            value = value.wait()
        return value

    return _iterate_state_dict(
        state_dict,
        sharded_tensor_func,
        dtensor_func,
        _identity_func,
        pg=pg,
        device=device,
        cpu_offload=cpu_offload,
        ranks_only=ranks_only,
        type_check=type_check,
    )
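

# A hedged sketch of gathering a DTensor-sharded state dict, assuming an already
# initialized default process group (e.g. gloo under torchrun) with two or more
# ranks. `_example_gather_state_dict` and the mesh shape are illustrative only.
def _example_gather_state_dict() -> None:
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Shard

    mesh = init_device_mesh("cpu", (dist.get_world_size(),))
    sharded = distribute_tensor(torch.randn(8, 4), mesh, [Shard(0)])
    state_dict = {"layer.weight": sharded}

    full = _gather_state_dict(state_dict, cpu_offload=True)
    # Every rank now holds the unsharded tensor on CPU.
    assert full["layer.weight"].shape == (8, 4)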


def _offload_state_dict_to_cpu(
    state_dict: Dict[str, Any],
    *,
    ranks_only: Tuple[int, ...] = (),
    type_check: bool = True,
) -> Dict[str, Any]:
    """
    Given a state_dict, this API offloads all the tensors to CPU memory.

    Args:
        state_dict (Dict[str, Any]): the target state_dict.
        ranks_only (Tuple[int, ...]): if this tuple is empty, all ranks will
            have the same state_dicts. Otherwise, only ranks in ``ranks_only``
            have the state_dicts; other ranks get empty state_dicts.
        type_check (bool): check if the instance data type is a supported type
            that can be saved by DCP. The currently supported data types are
            torch.Tensor, DTensor, int, float, str, list, dict, None.

    Returns:
        The offloaded state dictionary.
    """

    ret = _iterate_state_dict(
        state_dict,
        _identity_func,
        _identity_func,
        _identity_func,
        pg=None,
        device=None,
        cpu_offload=True,
        ranks_only=ranks_only,
        type_check=type_check,
    )
    return ret
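

# A minimal sketch, assuming a CUDA device is present; with CPU-only tensors the
# call is effectively a structured copy. `_example_offload_to_cpu` is a
# hypothetical name for illustration.
def _example_offload_to_cpu() -> None:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    state_dict = {"w": torch.randn(4, 4, device=device), "epoch": 3}
    cpu_sd = _offload_state_dict_to_cpu(state_dict)
    assert cpu_sd["w"].device.type == "cpu"
    assert cpu_sd["epoch"] == 3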


def _copy_state_dict(
    state_dict: Dict[str, Any],
    copy_state_dict: Dict[str, Any],
    non_blocking: bool = False,
    type_check: bool = True,
) -> Dict[str, Any]:
    """
    Copies all tensors in a given state dict into a different state_dict with the
    same structure. Additionally, a copied state dict with the same value references
    is returned. Editing the keys of this returned state dict will not affect the
    passed-in copy_state_dict (but the value references are the same).

    .. warning::
        This function expects state_dict and copy_state_dict to share
        the same structure and data types.

    .. warning::
        The currently supported data types are
            torch.Tensor, DTensor, int, float, str, list, dict, None.

    Args:
        state_dict (Dict[str, Any]): the target state_dict.
        copy_state_dict (Dict[str, Any]):
            The state dict we are copying into. This state_dict must have exactly
            the same structure as the source ``state_dict``.
        non_blocking (bool): whether copy ops should be performed asynchronously.
        type_check (bool): check if the instance data type is a supported type
            that can be saved by DCP. The currently supported data types are
            torch.Tensor, DTensor, int, float, str, list, dict, None.

    Returns:
        State Dict copy
    """

    return _iterate_state_dict(
        state_dict,
        _identity_func,
        _identity_func,
        _identity_func,
        pg=None,
        device=None,
        cpu_offload=False,
        ranks_only=(),
        companion_obj=copy_state_dict,
        type_check=type_check,
        non_blocking=non_blocking,
    )
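

# A minimal sketch of copying one state dict into preallocated buffers of the same
# structure. `_example_copy_state_dict` is a hypothetical name for illustration.
def _example_copy_state_dict() -> None:
    src = {"w": torch.randn(2, 3), "meta": {"step": 7}}
    dst = {"w": torch.empty(2, 3), "meta": {"step": 0}}
    out = _copy_state_dict(src, dst)
    # `dst["w"]` was filled in place; `out` shares the same tensor objects.
    assert torch.equal(dst["w"], src["w"])
    assert out["w"] is dst["w"]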


def _create_cpu_state_dict(
    state_dict: Dict[str, Any], pin_memory: bool = False, share_memory: bool = False
) -> Dict[str, Any]:
    """
    Given a state_dict, create another state_dict with the same structure and elements.
    However, all tensors in the returned state_dict are new tensors on CPU. These
    tensors can be placed on pin_memory or share_memory based on the provided arguments.

    .. warning::
        Setting both `pin_memory` and `share_memory` to True significantly increases the
        latency of this method because of the nuances which require us to register memory
        as pinned directly as opposed to relying on the pin_memory cache allocator. This
        option should only be used for long-lived tensors that are required to be shared.
        This overhead does not apply as long as at least one of `pin_memory` or
        `share_memory` is set to False.

    """

    def tensor_func(
        obj: torch.Tensor,
        pg: Optional[dist.ProcessGroup],
        device: Optional[torch.device],
        _: Any,
    ) -> torch.Tensor:
        if len(obj.size()) == 0:
            return torch.tensor(0, dtype=obj.dtype)

        if share_memory:
            t = torch.empty(*tuple(obj.size()), dtype=obj.dtype)
            t = t.share_memory_()
            if pin_memory:

                def unpin_memory(t):
                    succ = int(torch.cuda.cudart().cudaHostUnregister(t.data_ptr()))
                    assert (
                        succ == 0
                    ), f"Unpinning shared memory failed with error-code: {succ}"

                weakref.finalize(t, unpin_memory, t)
                succ = int(
                    torch.cuda.cudart().cudaHostRegister(
                        t.data_ptr(),
                        t.numel() * t.element_size(),
                        1,  # lines up with 'cudaHostRegisterPortable'
                    )
                )
                assert (
                    succ == 0
                ), f"Pinning shared memory failed with error-code: {succ}"
            return t
        elif pin_memory:
            return torch.empty(*tuple(obj.size()), dtype=obj.dtype).pin_memory()
        else:
            return torch.empty(*tuple(obj.size()), dtype=obj.dtype)

    ret = _iterate_state_dict(
        state_dict,
        _identity_func,
        _identity_func,
        tensor_func,
        pg=None,
        device=None,
        cpu_offload=False,
        ranks_only=(),
        type_check=False,
    )
    return ret
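

# A hedged sketch of the staging pattern these helpers enable, assuming a CUDA
# device: allocate pinned CPU buffers once, then repeatedly copy GPU state into
# them asynchronously. `_example_pinned_staging` is a hypothetical name.
def _example_pinned_staging() -> None:
    gpu_sd = {"w": torch.randn(1024, 1024, device="cuda")}
    # One-time allocation of CPU buffers mirroring the structure of gpu_sd.
    staging = _create_cpu_state_dict(gpu_sd, pin_memory=True)
    # Non-blocking D2H copies can overlap with subsequent GPU work.
    _copy_state_dict(gpu_sd, staging, non_blocking=True)
    torch.cuda.synchronize()  # make the copies visible before reading `staging`
    assert staging["w"].device.type == "cpu"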


def _check_state_dict_similarity(
    state_dict: Dict[str, Any],
    compared_state_dict: Dict[str, Any],
) -> bool:
    """
    Given two state_dicts, check if the structures are the same. If a
    [key, tensor] pair exists in one state_dict, there must be a corresponding
    pair, [key, other_tensor], in the other state_dict, where tensor and
    other_tensor have the same size and dtype.

    Return the check result.
    """

    def tensor_func(
        obj: torch.Tensor,
        pg: Optional[dist.ProcessGroup],
        device: Optional[torch.device],
        companion_obj: Any,
    ) -> torch.Tensor:
        if companion_obj.dtype != obj.dtype or companion_obj.size() != obj.size():
            raise CompanionMismatch
        return obj

    try:
        _iterate_state_dict(
            state_dict,
            _identity_func,
            _identity_func,
            tensor_func,
            pg=None,
            device=None,
            cpu_offload=False,
            ranks_only=(),
            companion_obj=compared_state_dict,
            type_check=False,
        )
    except CompanionMismatch:
        return False

    return True
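

# A minimal sketch: structure and tensor metadata must match; values need not.
# `_example_check_similarity` is a hypothetical name for illustration.
def _example_check_similarity() -> None:
    a = {"w": torch.zeros(2, 2), "tag": "ckpt"}
    b = {"w": torch.ones(2, 2), "tag": "other"}
    assert _check_state_dict_similarity(a, b)  # same keys, sizes, dtypes

    c = {"w": torch.zeros(2, 2, dtype=torch.float64), "tag": "ckpt"}
    assert not _check_state_dict_similarity(a, c)  # dtype differs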


class _TensorInfo(NamedTuple):
    size: torch.Size
    dtype: torch.dtype


def _broadcast_tensors(
    full_state_dict: Dict[str, Any],
    local_state_dict: Dict[str, Any],
    keys: List[str],
    device: torch.device,
    pg: Optional[dist.ProcessGroup] = None,
) -> None:
    tensors = []
    for key in keys:
        if dist.get_rank() == 0:
            full_state = full_state_dict[key]
            assert isinstance(full_state, torch.Tensor)
            full_tensor = full_state.detach().to(device)
        else:
            tensor_info = full_state_dict[key]
            full_tensor = torch.empty(
                size=tensor_info.size,
                device=device,
                dtype=tensor_info.dtype,
            )

        tensors.append(full_tensor)
        local_state = local_state_dict.get(key, None)
        if local_state is None:
            continue
        elif isinstance(local_state, DTensor):
            local_state_dict[key] = (local_state, full_tensor)
        else:
            local_state_dict[key] = full_tensor

    if pg is None:
        pg = dist.distributed_c10d._get_default_group()

    if len(tensors) > 1:
        dist._broadcast_coalesced(pg, tensors, 500, 0)
    else:
        dist.broadcast(tensors[0], src=0, group=pg)

    _distribute_tensors(local_state_dict, keys, device, pg)


def _distribute_tensors(
    local_state_dict: Dict[str, Any],
    keys: List[str],
    device: torch.device,
    pg: Optional[dist.ProcessGroup] = None,
) -> None:
    if pg is None:
        pg = dist.distributed_c10d._get_default_group()
    for key in keys:
        _local_state = local_state_dict.get(key, None)
        if _local_state is None or torch.is_tensor(_local_state):
            continue

        local_state = _local_state[0]
        full_tensor = _local_state[1]

        shape, offset = compute_local_shape_and_global_offset(
            full_tensor.shape, local_state.device_mesh, local_state.placements
        )
        slices = [slice(offset[i], shape[i] + offset[i]) for i in range(len(shape))]
        local_tensor = full_tensor[slices]
        # TODO: currently, we cannot handle strided sharding if the dp dimension is
        # not even. For example, one case that is not yet supported is when
        # placements = (Shard(0), _StridedShard(0, sf=2)).
        local_state_dict[key] = DTensor.from_local(
            local_tensor,
            local_state.device_mesh,
            local_state.placements,
            shape=local_state.shape,
            stride=local_state.stride(),
        )


def _broadcast_state_dict(
    full_state_dict: Dict[str, Any],
    local_state_dict: Dict[str, Any],
    device: torch.device,
    pg: Optional[dist.ProcessGroup] = None,
    strict: bool = False,
) -> None:
    # Broadcast from rank0's `full_state_dict` to all ranks' `local_state_dict`.
    # If strict is True, any keys in `local_state_dict` but not in `full_state_dict`
    # will be removed from `local_state_dict`.
    ret = {}
    if dist.get_rank() == 0:
        for key, value in full_state_dict.items():
            if not torch.is_tensor(value):
                ret[key] = value
            elif value.dim() == 0:
                ret[key] = value.cpu()
            else:
                ret[key] = _TensorInfo(value.size(), value.dtype)

    broadcast_list = [ret]
    dist.broadcast_object_list(broadcast_list, src=0, group=pg)
    ret = broadcast_list[0]

    # Gather values
    keys = []
    local_state_dict_keys = set(local_state_dict.keys())
    global_keys = set()
    for key, value in ret.items():
        global_keys.add(key)
        if not isinstance(value, _TensorInfo):
            if key in local_state_dict:
                local_state_dict[key] = value
            continue

        if dist.get_rank() == 0:
            ret[key] = full_state_dict[key]

        keys.append(key)
        # Broadcast each tensor individually to avoid OOM for now.
        if len(keys) >= 1:
            _broadcast_tensors(ret, local_state_dict, keys, device, pg)
            keys.clear()

    if strict:
        if missing_keys := (local_state_dict_keys - global_keys):
            for key in missing_keys:
                local_state_dict.pop(key)

    if keys:
        _broadcast_tensors(ret, local_state_dict, keys, device, pg)
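

# A hedged sketch of the rank0-broadcast flow, assuming torchrun with an initialized
# process group (e.g. gloo): rank 0 holds the full checkpoint, other ranks receive
# metadata only, and every rank's local state dict is filled in place.
# `_example_broadcast_state_dict` and the shapes are illustrative only.
def _example_broadcast_state_dict() -> None:
    if dist.get_rank() == 0:
        full_sd = {"w": torch.randn(8, 8), "epoch": 5}
    else:
        full_sd = {}
    local_sd = {"w": torch.empty(8, 8), "epoch": 0}
    _broadcast_state_dict(full_sd, local_sd, device=torch.device("cpu"))
    assert local_sd["epoch"] == 5  # non-tensor values come along too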


def _distribute_state_dict(
    full_state_dict: Dict[str, Any],
    local_state_dict: Dict[str, Any],
    device: torch.device,
    pg: Optional[dist.ProcessGroup] = None,
) -> None:
    # This mirrors ``_broadcast_state_dict`` for the full_state_dict=True,
    # broadcast_from_rank0=False case: every rank already has the full
    # state_dict, so the broadcast is skipped and each rank distributes
    # the tensors directly.
    for key, value in full_state_dict.items():
        if not torch.is_tensor(value):
            local_state_dict[key] = value
        elif value.dim() == 0:
            local_state_dict[key] = value.cpu()
        else:
            assert isinstance(value, torch.Tensor)
            local_state = local_state_dict.get(key, None)
            if local_state is None:
                continue
            elif isinstance(local_state, DTensor):
                local_state_dict[key] = distribute_tensor(
                    value.detach().to(device),
                    local_state.device_mesh,
                    local_state.placements,
                )
            else:
                local_state_dict[key] = value.detach().to(device)
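

# A hedged sketch, assuming an initialized process group and a 1-D device mesh:
# every rank holds the same full tensor and shards it into its local DTensor slot.
# `_example_distribute_state_dict` is a hypothetical name.
def _example_distribute_state_dict() -> None:
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Shard

    mesh = init_device_mesh("cpu", (dist.get_world_size(),))
    local_sd = {"w": distribute_tensor(torch.empty(8, 4), mesh, [Shard(0)])}
    torch.manual_seed(0)  # make the full tensor identical on every rank
    full_sd = {"w": torch.randn(8, 4)}
    _distribute_state_dict(full_sd, local_sd, device=torch.device("cpu"))
    assert isinstance(local_sd["w"], DTensor)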


# These APIs are from torch.distributed.checkpoint.
# TODO: We should consolidate the code here as not all modules can depend on DCP.
PATH_ITEM = Union[str, int]
OBJ_PATH = Tuple[PATH_ITEM, ...]
FLATTEN_MAPPING = Dict[str, OBJ_PATH]
STATE_DICT_TYPE = Dict[str, Any]
CONTAINER_TYPE = MutableMapping[PATH_ITEM, Any]


def _traverse_state_dict(
    state_dict: STATE_DICT_TYPE,
    visitor: Callable[[OBJ_PATH, Any], None],
) -> None:
    """
    Invoke ``visitor`` for each value recursively in ``state_dict``.
    Mappings, lists, and tuples are flattened; other value types are treated
    as terminal values and will invoke ``visitor``.
    """

    def _traverse_obj(path: OBJ_PATH, value: Any) -> None:
        if isinstance(value, Mapping):
            for k, v in value.items():
                _traverse_obj(path + (str(k),), v)
        elif isinstance(value, (list, tuple)):
            for i, v in enumerate(value):
                _traverse_obj(path + (i,), v)
        else:
            visitor(path, value)

    for key, value in state_dict.items():
        _traverse_obj((str(key),), value)
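

# A minimal sketch of the visitor pattern `_traverse_state_dict` implements:
# collect the object path of every tensor leaf. Hypothetical example only.
def _example_traverse() -> None:
    state_dict = {"opt": {"params": [torch.zeros(1), torch.ones(1)]}, "step": 3}
    tensor_paths = []

    def visitor(path: OBJ_PATH, value: Any) -> None:
        if torch.is_tensor(value):
            tensor_paths.append(path)

    _traverse_state_dict(state_dict, visitor)
    assert tensor_paths == [("opt", "params", 0), ("opt", "params", 1)]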


def _flatten_state_dict(
    state_dict: STATE_DICT_TYPE,
) -> Tuple[STATE_DICT_TYPE, FLATTEN_MAPPING]:
    """
    Flatten ``state_dict`` made of nested dicts and lists into a top level dictionary.

    Use ``_unflatten_state_dict`` to revert this process.
    Returns:
        A tuple with the flattened state_dict and a mapping from the new keys to the
        original object paths.
    N.B. The new keys are derived from the object paths, joined by dot.
        For example: ``{'a': {'b': ...}}`` results in the key `a.b`.
    """
    flattened: STATE_DICT_TYPE = {}
    mappings: FLATTEN_MAPPING = {}

    def flat_copy(path: OBJ_PATH, value: Any) -> None:
        new_fqn = ".".join(map(str, path))
        if new_fqn in flattened:
            raise ValueError(f"duplicated flatten key {new_fqn}")
        flattened[new_fqn] = value
        mappings[new_fqn] = path

    _traverse_state_dict(state_dict, flat_copy)
    return flattened, mappings


def _set_element(root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: Any) -> None:
    """Set ``value`` in ``root_dict`` along the ``path`` object path."""
    cur_container = cast(CONTAINER_TYPE, root_dict)

    def extend_list(lst: List[Any], idx: int) -> None:
        while len(lst) <= idx:
            lst.append(None)

    for i in range(1, len(path)):
        prev_key = path[i - 1]
        key = path[i]
        def_val: Union[CONTAINER_TYPE, List[Any]] = {} if type(key) == str else []

        if isinstance(cur_container, Mapping):
            cur_container = cast(
                CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val)
            )
        else:
            extend_list(cur_container, prev_key)
            if cur_container[prev_key] is None:
                cur_container[prev_key] = def_val
            cur_container = cur_container[prev_key]

    key = path[-1]
    if type(key) == int:
        extend_list(cast(List[Any], cur_container), key)

    cur_container[key] = value


def _unflatten_state_dict(
    state_dict: STATE_DICT_TYPE, mapping: FLATTEN_MAPPING
) -> STATE_DICT_TYPE:
    """Restore the original nested state_dict according to ``mapping`` and the flattened ``state_dict``."""
    nested: STATE_DICT_TYPE = {}
    for key, value in state_dict.items():
        _set_element(nested, mapping[key], value)
    return nested
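

# A minimal sketch of the flatten/unflatten round trip. Hypothetical example only.
def _example_flatten_roundtrip() -> None:
    nested = {"model": {"layers": [{"w": 1}, {"w": 2}]}, "step": 10}
    flat, mapping = _flatten_state_dict(nested)
    assert flat == {"model.layers.0.w": 1, "model.layers.1.w": 2, "step": 10}
    assert mapping["model.layers.0.w"] == ("model", "layers", 0, "w")
    restored = _unflatten_state_dict(flat, mapping)
    assert restored == nested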