# mypy: allow-untyped-defs
import copy
import io
import math
import weakref
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    List,
    Mapping,
    MutableMapping,
    NamedTuple,
    Optional,
    Tuple,
    TYPE_CHECKING,
    Union,
)

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed._functional_collectives import AsyncCollectiveTensor


if dist.is_available() or TYPE_CHECKING:
    from torch.distributed import distributed_c10d
    from torch.distributed._shard.sharded_tensor import ShardedTensor
    from torch.distributed.tensor import distribute_tensor, DTensor, Replicate
    from torch.distributed.tensor._utils import compute_local_shape_and_global_offset


def _identity_func(
    obj: torch.Tensor,
    pg: Optional[dist.ProcessGroup],
    device: Optional[torch.device],
    companion_obj: Any,
) -> torch.Tensor:
    return obj


def _all_gather_sharded_tensor(
    sharded_tensor: "ShardedTensor",
    pg: Optional[dist.ProcessGroup] = None,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    if pg is None:
        pg = distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    shards = sharded_tensor.local_shards()
    dim_0_size = sharded_tensor.size()[0]  # type: ignore[index]
    tensor_numel = sharded_tensor.size().numel()  # type: ignore[union-attr]
    chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size
    pg_device = (
        distributed_c10d._get_pg_default_device(pg) if device is None else device
    )
    if shards:
        local_tensor = shards[0].tensor.flatten()
        if local_tensor.device.type != pg_device.type:
            local_tensor = local_tensor.to(pg_device)
        num_padding = chunk_size - local_tensor.numel()
        if num_padding > 0:
            local_tensor = F.pad(local_tensor, [0, num_padding])
    else:
        local_tensor = torch.zeros(
            chunk_size, dtype=sharded_tensor.dtype, device=pg_device
        )

    tensor = torch.empty(
        chunk_size * world_size,
        dtype=local_tensor.dtype,
        device=pg_device,
    )
    dist.all_gather_into_tensor(tensor, local_tensor, group=pg)

    tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size())
    return tensor


class CompanionMismatch(Exception):
    ...
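
# The chunk_size arithmetic in _all_gather_sharded_tensor pads every rank's
# flattened shard to a common length so that a single all_gather_into_tensor
# call suffices. A minimal worked example (illustrative numbers only, assuming
# a (5, 4) ShardedTensor sharded over dim 0 across 2 ranks):
#
#   dim_0_size = 5, tensor_numel = 20, world_size = 2
#   chunk_size = ceil(5 / 2) * 20 // 5 = 12   # rank 0: 12 elems, rank 1: 8 elems + 4 padding
#   gathered   = chunk_size * world_size = 24 elems
#   result     = gathered.narrow(0, 0, 20).reshape(5, 4)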


def _iterate_state_dict(
    iter_object: Any,
    sharded_tensor_func: Callable,
    dtensor_func: Callable,
    tensor_func: Callable,
    *,
    pg: Optional[dist.ProcessGroup] = None,
    device: Optional[torch.device] = None,
    cpu_offload: bool = False,
    companion_obj: Any = None,
    ranks_only: Tuple[int, ...] = (),
    type_check: bool = True,
    non_blocking: bool = True,
) -> Dict[str, Any]:
    """Iterate through the state dict, applying the given functions to each tensor type.

    Args:
        iter_object (Any): the target state_dict.
        sharded_tensor_func (Callable): the function to apply to ShardedTensor.
        dtensor_func (Callable): the function to apply to DTensor.
        tensor_func (Callable): the function to apply to Tensor.
        pg (Optional[dist.ProcessGroup]): process group passed to the tensor functions.
        device (Optional[torch.device]): device passed to the tensor functions.
        cpu_offload (bool): whether to offload the tensors to CPU memory. This option is
            ignored if a companion_obj is supplied.
        companion_obj (Any): a companion object to the state dict. If this object
            is supplied, we attempt to copy the tensor to the companion object.
        ranks_only (Tuple[int, ...]): if this tuple is empty, all ranks will
            have the same state_dicts. Otherwise only ranks that are in ``ranks_only``
            have the same state_dicts. Other ranks will get empty state_dicts.
        type_check (bool): check if the instance data type is a supported type
            that can be saved by DCP. The currently supported data types are
            torch.Tensor, DTensor, int, float, str, list, dict, None.
        non_blocking (bool): whether to use non-blocking copy when copying to the companion object.
    """
    # TODO: should we use pytree?
    cpu_device = torch.device("cpu")
    if isinstance(iter_object, ShardedTensor):
        ret = sharded_tensor_func(iter_object, pg, device, companion_obj)
    elif isinstance(iter_object, DTensor):
        ret = dtensor_func(iter_object, pg, device, companion_obj)
    elif isinstance(iter_object, torch.Tensor):
        ret = tensor_func(iter_object, pg, device, companion_obj)
    elif (
        isinstance(iter_object, (int, float, str, bytes, io.BytesIO))
        or iter_object is None
    ):
        ret = iter_object
    elif isinstance(iter_object, dict):
        if companion_obj is not None and (
            not isinstance(companion_obj, dict)
            or set(companion_obj.keys()) != set(iter_object.keys())
        ):
            msg = (
                ""
                if isinstance(companion_obj, dict)
                else f"{set(companion_obj.keys())=} {set(iter_object.keys())=}"
            )
            raise CompanionMismatch(msg)

        ret = {
            key: _iterate_state_dict(
                value,
                sharded_tensor_func,
                dtensor_func,
                tensor_func,
                pg=pg,
                device=device,
                cpu_offload=cpu_offload,
                companion_obj=companion_obj[key] if companion_obj is not None else None,
                ranks_only=ranks_only,
                type_check=type_check,
                non_blocking=non_blocking,
            )
            for key, value in iter_object.items()
        }
    elif isinstance(iter_object, (list, tuple)):
        if companion_obj is not None and (
            not isinstance(companion_obj, (list, tuple))
            or len(companion_obj) != len(iter_object)
        ):
            raise CompanionMismatch

        ret = [
            _iterate_state_dict(
                v,
                sharded_tensor_func,
                dtensor_func,
                tensor_func,
                pg=pg,
                device=device,
                cpu_offload=cpu_offload,
                companion_obj=companion_obj[idx] if companion_obj is not None else None,
                ranks_only=ranks_only,
                type_check=type_check,
                non_blocking=non_blocking,
            )
            for idx, v in enumerate(iter_object)
        ]
        if isinstance(iter_object, tuple):
            ret = tuple(ret)
    elif not type_check:
        ret = copy.deepcopy(iter_object)
    else:
        raise ValueError(f"Unexpected value type {type(iter_object)}")

    if not ranks_only or dist.get_rank(pg) in ranks_only:
        if isinstance(ret, torch.Tensor):
            if cpu_offload and companion_obj is None:
                ret = ret.to(cpu_device)

            if companion_obj is not None:
                # TODO: support DTensor
                companion_obj.copy_(ret, non_blocking=non_blocking)
                ret = companion_obj
    else:
        ret = {} if isinstance(ret, dict) else None

    return ret
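
# The traversal above is the engine behind every helper in this module. A
# hedged usage sketch (not executed here; ``some_state_dict`` is a hypothetical
# nested dict of tensors): detach every plain tensor while leaving
# ShardedTensor/DTensor values untouched.
#
#   detached = _iterate_state_dict(
#       some_state_dict,
#       _identity_func,                               # ShardedTensor handler
#       _identity_func,                               # DTensor handler
#       lambda t, pg, device, companion: t.detach(),  # plain Tensor handler
#   )
#
# Non-tensor leaves (int, float, str, bytes, io.BytesIO, None) pass through
# unchanged, and dict/list/tuple containers are rebuilt recursively.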


def _gather_state_dict(
    state_dict: Dict[str, Any],
    *,
    pg: Optional[dist.ProcessGroup] = None,
    device: Optional[torch.device] = None,
    cpu_offload: bool = False,
    ranks_only: Tuple[int, ...] = (),
    type_check: bool = True,
) -> Dict[str, Any]:
    """
    Given a state_dict, this API gathers all the ShardedTensors or DTensors in
    the state_dict.

    Args:
        state_dict (Dict[str, Any]): the target sharded state_dict.
        pg (Optional[dist.ProcessGroup]): the process group that is used to
            gather ShardedTensor. Note that gathering a DTensor will use
            the DeviceMesh, so this argument will be ignored when gathering a
            DTensor.
        device (Optional[torch.device]): the device that is used to
            perform allgather for ShardedTensor. Note that gathering a DTensor
            will use the DeviceMesh, so this argument will be ignored when
            gathering a DTensor.
        cpu_offload (bool): whether to offload the tensors to CPU memory. The
            default value is False.
        ranks_only (Tuple[int, ...]): if this tuple is empty, all ranks will
            have the same state_dicts. Otherwise only ranks that are in ``ranks_only``
            have the same state_dicts. Other ranks will get empty state_dicts.
        type_check (bool): check if the instance data type is a supported type
            that can be saved by DCP. The currently supported data types are
            torch.Tensor, DTensor, int, float, str, list, dict, None.

    Returns:
        The gathered state dictionary.
    """

    def sharded_tensor_func(value, pg, device, companion_obj):
        # ShardedTensor does not seem to record the original device type.
        # So if the tensor is moved to CPU, we won't know the original type.
        # As a result, we have to rely on the user to tell us the correct one.
        cpu_device = torch.device("cpu")
        output_tensor = _all_gather_sharded_tensor(value, pg, device)
        local_shard_device = (
            value.local_shards()[0].tensor.device
            if value.local_shards()
            else cpu_device
        )
        if output_tensor.device != local_shard_device:
            value = output_tensor.to(local_shard_device)
        else:
            value = output_tensor
        return value

    def dtensor_func(value, pg, device, companion_obj):
        if value.device != value.device_mesh.device_type:
            value = value.to(value.device_mesh.device_type)
        # FSDP all_gather: [Shard(0)] -> [Replicate()]
        # HSDP all_gather: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
        # 2D FSDP + TP all_gather:
        # - [Shard(0), Shard(n)] -> [Replicate(), Replicate()]
        # - [Shard(0), Replicate()] -> [Replicate(), Replicate()]
        placements = [Replicate() for _ in value.placements]
        value = value.redistribute(
            device_mesh=value.device_mesh,
            placements=placements,
        )
        # Call `wait()` to force the tensor to be synchronous with respect
        # to the main stream.
        # See the discussion in https://github.com/pytorch/pytorch/pull/117799.
        value = value.to_local()
        if isinstance(value, AsyncCollectiveTensor):
            value = value.wait()
        return value

    return _iterate_state_dict(
        state_dict,
        sharded_tensor_func,
        dtensor_func,
        _identity_func,
        pg=pg,
        device=device,
        cpu_offload=cpu_offload,
        ranks_only=ranks_only,
        type_check=type_check,
    )
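
# Hedged usage sketch for _gather_state_dict (names such as ``model`` are
# illustrative; assumes the model is wrapped so its state_dict values are
# DTensor/ShardedTensor). Every rank must participate in the collectives;
# ranks outside ``ranks_only`` receive empty state_dicts.
#
#   sharded_sd = model.state_dict()
#   full_sd = _gather_state_dict(sharded_sd, cpu_offload=True, ranks_only=(0,))
#   if dist.get_rank() == 0:
#       torch.save(full_sd, "checkpoint.pt")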


def _offload_state_dict_to_cpu(
    state_dict: Dict[str, Any],
    *,
    ranks_only: Tuple[int, ...] = (),
    type_check: bool = True,
) -> Dict[str, Any]:
    """
    Given a state_dict, this API offloads all the tensors to CPU memory.

    Args:
        state_dict (Dict[str, Any]): the target state_dict.
        ranks_only (Tuple[int, ...]): if this tuple is empty, all ranks will
            have the same state_dicts. Otherwise only ranks that are in ``ranks_only``
            have the same state_dicts. Other ranks will get empty state_dicts.
        type_check (bool): check if the instance data type is a supported type
            that can be saved by DCP. The currently supported data types are
            torch.Tensor, DTensor, int, float, str, list, dict, None.

    Returns:
        The offloaded state dictionary.
    """

    ret = _iterate_state_dict(
        state_dict,
        _identity_func,
        _identity_func,
        _identity_func,
        pg=None,
        device=None,
        cpu_offload=True,
        ranks_only=ranks_only,
        type_check=type_check,
    )
    return ret


def _copy_state_dict(
    state_dict: Dict[str, Any],
    copy_state_dict: Dict[str, Any],
    non_blocking: bool = False,
    type_check: bool = True,
) -> Dict[str, Any]:
    """
    Copies all tensors in a given state dict into a different state_dict with the
    same structure. Additionally, a copied state dict with the same value references
    is returned. Editing the keys on this state dict will not affect the
    passed in copy_state_dict (but the value references are the same).

    .. warning::
        It is expected by this function that state_dict and copy_state_dict share
        the same structure and data types.

    .. warning::
        The currently supported data types are
        torch.Tensor, DTensor, int, float, str, list, dict, None.

    Args:
        state_dict (Dict[str, Any]): the target state_dict.
        copy_state_dict (Dict[str, Any]):
            The state dict we are copying into. This state_dict must have exactly
            the same structure as the source `state_dict`.
        non_blocking (bool): whether copy ops should be performed asynchronously.
        type_check (bool): check if the instance data type is a supported type
            that can be saved by DCP. The currently supported data types are
            torch.Tensor, DTensor, int, float, str, list, dict, None.

    Returns:
        State Dict copy
    """

    return _iterate_state_dict(
        state_dict,
        _identity_func,
        _identity_func,
        _identity_func,
        pg=None,
        device=None,
        cpu_offload=False,
        ranks_only=(),
        companion_obj=copy_state_dict,
        type_check=type_check,
        non_blocking=non_blocking,
    )
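
# Hedged sketch (illustrative only): _offload_state_dict_to_cpu creates new CPU
# copies without issuing any collective, while _copy_state_dict reuses an
# existing target with the same structure.
#
#   cpu_sd = _offload_state_dict_to_cpu(model.state_dict())           # fresh CPU copies
#   _copy_state_dict(model.state_dict(), cpu_sd, non_blocking=False)  # reuse cpu_sd later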


def _create_cpu_state_dict(
    state_dict: Dict[str, Any], pin_memory: bool = False, share_memory: bool = False
) -> Dict[str, Any]:
    """
    Given a state_dict, create another state_dict with the same structure and elements.
    However, all tensors in the returned state_dict are new tensors on CPU. These
    tensors can be placed on pin_memory or share_memory based on the provided arguments.

    .. warning::
        Setting both `pin_memory` and `share_memory` to True significantly increases the
        latency of this method because of the nuances which require us to register memory
        as pinned directly as opposed to relying on the pin_memory cache allocator. This
        option should only be used for long lived tensors which are required to be shared.
        This overhead does not apply as long as at least one of `pin_memory` or
        `share_memory` is set to False.
    """

    def tensor_func(
        obj: torch.Tensor,
        pg: Optional[dist.ProcessGroup],
        device: Optional[torch.device],
        _: Any,
    ) -> torch.Tensor:
        if len(obj.size()) == 0:
            return torch.tensor(0, dtype=obj.dtype)

        if share_memory:
            t = torch.empty(*tuple(obj.size()), dtype=obj.dtype)
            t = t.share_memory_()
            if pin_memory:

                def unpin_memory(t):
                    succ = int(torch.cuda.cudart().cudaHostUnregister(t.data_ptr()))
                    assert (
                        succ == 0
                    ), f"Unpinning shared memory failed with error-code: {succ}"

                weakref.finalize(t, unpin_memory, t)
                succ = int(
                    torch.cuda.cudart().cudaHostRegister(
                        t.data_ptr(),
                        t.numel() * t.element_size(),
                        1,  # lines up with 'cudaHostRegisterPortable'
                    )
                )
                assert (
                    succ == 0
                ), f"Pinning shared memory failed with error-code: {succ}"
            return t
        elif pin_memory:
            return torch.empty(*tuple(obj.size()), dtype=obj.dtype).pin_memory()
        else:
            return torch.empty(*tuple(obj.size()), dtype=obj.dtype)

    ret = _iterate_state_dict(
        state_dict,
        _identity_func,
        _identity_func,
        tensor_func,
        pg=None,
        device=None,
        cpu_offload=False,
        ranks_only=(),
        type_check=False,
    )
    return ret


def _check_state_dict_similarity(
    state_dict: Dict[str, Any],
    compared_state_dict: Dict[str, Any],
) -> bool:
    """
    Given two state_dicts, check if the structures are the same. If a
    [key, tensor] pair exists in one state_dict, there must be a corresponding
    pair, [key, other_tensor], in the other state_dict, where tensor and
    other_tensor have the same size and dtype.

    Return the check result.
    """

    def tensor_func(
        obj: torch.Tensor,
        pg: Optional[dist.ProcessGroup],
        device: Optional[torch.device],
        companion_obj: Any,
    ) -> torch.Tensor:
        if companion_obj.dtype != obj.dtype or companion_obj.size() != obj.size():
            raise CompanionMismatch
        return obj

    try:
        _iterate_state_dict(
            state_dict,
            _identity_func,
            _identity_func,
            tensor_func,
            pg=None,
            device=None,
            cpu_offload=False,
            ranks_only=(),
            companion_obj=compared_state_dict,
            type_check=False,
        )
    except CompanionMismatch:
        return False

    return True


class _TensorInfo(NamedTuple):
    size: torch.Size
    dtype: torch.dtype


def _broadcast_tensors(
    full_state_dict: Dict[str, Any],
    local_state_dict: Dict[str, Any],
    keys: List[str],
    device: torch.device,
    pg: Optional[dist.ProcessGroup] = None,
) -> None:
    tensors = []
    for key in keys:
        if dist.get_rank() == 0:
            full_state = full_state_dict[key]
            assert isinstance(full_state, torch.Tensor)
            full_tensor = full_state.detach().to(device)
        else:
            tensor_info = full_state_dict[key]
            full_tensor = torch.empty(
                size=tensor_info.size,
                device=device,
                dtype=tensor_info.dtype,
            )

        tensors.append(full_tensor)
        local_state = local_state_dict.get(key, None)
        if local_state is None:
            continue
        elif isinstance(local_state, DTensor):
            local_state_dict[key] = (local_state, full_tensor)
        else:
            local_state_dict[key] = full_tensor

    if pg is None:
        pg = dist.distributed_c10d._get_default_group()

    if len(tensors) > 1:
        dist._broadcast_coalesced(pg, tensors, 500, 0)
    else:
        dist.broadcast(tensors[0], src=0, group=pg)

    _distribute_tensors(local_state_dict, keys, device, pg)
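
# Hedged sketch of the staging pattern these helpers enable (variable names are
# illustrative; whether a checkpoint flow does this is up to the caller):
#
#   gpu_sd = model.state_dict()
#   staging_sd = _create_cpu_state_dict(gpu_sd, pin_memory=True)  # allocate once
#   assert _check_state_dict_similarity(gpu_sd, staging_sd)       # same structure/dtypes/sizes
#   _copy_state_dict(gpu_sd, staging_sd, non_blocking=True)       # async D2H into pinned memory
#   torch.cuda.synchronize()  # wait for the copies before reading staging_sd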


def _distribute_tensors(
    local_state_dict: Dict[str, Any],
    keys: List[str],
    device: torch.device,
    pg: Optional[dist.ProcessGroup] = None,
) -> None:
    if pg is None:
        pg = dist.distributed_c10d._get_default_group()
    for key in keys:
        _local_state = local_state_dict.get(key, None)
        if _local_state is None or torch.is_tensor(_local_state):
            continue

        local_state = _local_state[0]
        full_tensor = _local_state[1]

        shape, offset = compute_local_shape_and_global_offset(
            full_tensor.shape, local_state.device_mesh, local_state.placements
        )
        slices = [slice(offset[i], shape[i] + offset[i]) for i in range(len(shape))]
        local_tensor = full_tensor[slices]
        # TODO: currently, we cannot handle strided sharding if the dp dimension is not even. For example,
        # one of the cases that is not yet supported is when placements = (Shard(0), _StridedShard(0, sf=2)).
        local_state_dict[key] = DTensor.from_local(
            local_tensor,
            local_state.device_mesh,
            local_state.placements,
            shape=local_state.shape,
            stride=local_state.stride(),
        )


def _broadcast_state_dict(
    full_state_dict: Dict[str, Any],
    local_state_dict: Dict[str, Any],
    device: torch.device,
    pg: Optional[dist.ProcessGroup] = None,
    strict: bool = False,
) -> None:
    # Broadcast from rank0's `full_state_dict` to all ranks' `local_state_dict`.
    # If strict is True, any keys in `local_state_dict` but not in `full_state_dict`
    # will be removed from `local_state_dict`.
    ret = {}
    if dist.get_rank() == 0:
        for key, value in full_state_dict.items():
            if not torch.is_tensor(value):
                ret[key] = value
            elif value.dim() == 0:
                ret[key] = value.cpu()
            else:
                ret[key] = _TensorInfo(value.size(), value.dtype)

    broadcast_list = [ret]
    dist.broadcast_object_list(broadcast_list, src=0, group=pg)
    ret = broadcast_list[0]

    # Gather values
    keys = []
    local_state_dict_keys = set(local_state_dict.keys())
    global_keys = set()
    for key, value in ret.items():
        global_keys.add(key)
        if not isinstance(value, _TensorInfo):
            if key in local_state_dict:
                local_state_dict[key] = value
            continue

        if dist.get_rank() == 0:
            ret[key] = full_state_dict[key]

        keys.append(key)
        # Broadcast every tensor to avoid OOM for now.
        if len(keys) >= 1:
            _broadcast_tensors(ret, local_state_dict, keys, device, pg)
            keys.clear()

    if strict:
        if missing_keys := (local_state_dict_keys - global_keys):
            for key in missing_keys:
                local_state_dict.pop(key)

    if keys:
        _broadcast_tensors(ret, local_state_dict, keys, device, pg)
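
# Hedged example for _broadcast_state_dict (illustrative; assumes ``model``'s
# state_dict values are DTensors): only rank 0 needs to hold the full
# checkpoint, the other ranks pass an empty dict. Tensor metadata is broadcast
# first, then each tensor is broadcast and re-sharded in place.
#
#   full_sd = torch.load("ckpt.pt") if dist.get_rank() == 0 else {}
#   local_sd = model.state_dict()
#   _broadcast_state_dict(full_sd, local_sd, device=torch.device("cuda"))
#   model.load_state_dict(local_sd)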


def _distribute_state_dict(
    full_state_dict: Dict[str, Any],
    local_state_dict: Dict[str, Any],
    device: torch.device,
    pg: Optional[dist.ProcessGroup] = None,
) -> None:
    # full_state_dict = True, broadcast_from_rank0 = False here. Each rank has
    # the full_state_dict. Skip the broadcast in ``_broadcast_state_dict`` and
    # distribute the tensors on each rank instead.
    for key, value in full_state_dict.items():
        if key not in full_state_dict:
            continue
        if not torch.is_tensor(value):
            local_state_dict[key] = value
        elif value.dim() == 0:
            local_state_dict[key] = value.cpu()
        else:
            assert isinstance(value, torch.Tensor)
            local_state = local_state_dict.get(key, None)
            if local_state is None:
                continue
            elif isinstance(local_state, DTensor):
                local_state_dict[key] = distribute_tensor(
                    value.detach().to(device),
                    local_state.device_mesh,
                    local_state.placements,
                )
            else:
                local_state_dict[key] = value.detach().to(device)


# These APIs are from torch.distributed.checkpoint.
# TODO: We should consolidate the code here as not all modules can depend on
# DCP.
PATH_ITEM = Union[str, int]
OBJ_PATH = Tuple[PATH_ITEM, ...]
FLATTEN_MAPPING = Dict[str, OBJ_PATH]
STATE_DICT_TYPE = Dict[str, Any]
CONTAINER_TYPE = MutableMapping[PATH_ITEM, Any]


def _traverse_state_dict(
    state_dict: STATE_DICT_TYPE,
    visitor: Callable[[OBJ_PATH, Any], None],
) -> None:
    """
    Invoke ``visitor`` for each value recursively in ``state_dict``.

    Mappings, lists, and tuples are flattened; all other value types are treated
    as terminal values and will invoke ``visitor``.
    """

    def _traverse_obj(path: OBJ_PATH, value: Any) -> None:
        if isinstance(value, Mapping):
            for k, v in value.items():
                _traverse_obj(path + (str(k),), v)
        elif isinstance(value, (list, tuple)):
            for i, v in enumerate(value):
                _traverse_obj(path + (i,), v)
        else:
            visitor(path, value)

    for key, value in state_dict.items():
        _traverse_obj((str(key),), value)
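
# Hedged sketch (illustrative): collecting the path of every leaf with
# _traverse_state_dict. Mapping keys are stringified, list/tuple indices stay
# integers.
#
#   paths = []
#   _traverse_state_dict(sd, lambda path, value: paths.append(path))
#   # for sd == {"opt": {"param_groups": [{"lr": 0.1}]}} the visitor is called
#   # once, with path ("opt", "param_groups", 0, "lr") and value 0.1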


def _flatten_state_dict(
    state_dict: STATE_DICT_TYPE,
) -> Tuple[STATE_DICT_TYPE, FLATTEN_MAPPING]:
    """
    Flatten ``state_dict`` made of nested dicts and lists into a top level dictionary.

    Use ``_unflatten_state_dict`` to revert this process.
    Returns:
        A tuple with the flattened state_dict and a mapping from original to new state_dict.
        N.B. The new keys are derived from the object paths, joined by dot.
        For example: ``{'a': {'b': ...}}`` results in the key `a.b`.
    """
    flattened: STATE_DICT_TYPE = {}
    mappings: FLATTEN_MAPPING = {}

    def flat_copy(path: OBJ_PATH, value: Any) -> None:
        new_fqn = ".".join(map(str, path))
        if new_fqn in flattened:
            raise ValueError(f"duplicated flatten key {new_fqn}")
        flattened[new_fqn] = value
        mappings[new_fqn] = path

    _traverse_state_dict(state_dict, flat_copy)
    return flattened, mappings


def _set_element(root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: Any) -> None:
    """Set ``value`` in ``root_dict`` along the ``path`` object path."""
    cur_container = cast(CONTAINER_TYPE, root_dict)

    def extend_list(lst: List[Any], idx: int) -> None:
        while len(lst) <= idx:
            lst.append(None)

    for i in range(1, len(path)):
        prev_key = path[i - 1]
        key = path[i]
        def_val: Union[CONTAINER_TYPE, List[Any]] = {} if type(key) == str else []

        if isinstance(cur_container, Mapping):
            cur_container = cast(
                CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val)
            )
        else:
            extend_list(cur_container, prev_key)
            if cur_container[prev_key] is None:
                cur_container[prev_key] = def_val
            cur_container = cur_container[prev_key]

    key = path[-1]
    if type(key) == int:
        extend_list(cast(List[Any], cur_container), key)

    cur_container[key] = value


def _unflatten_state_dict(
    state_dict: STATE_DICT_TYPE, mapping: FLATTEN_MAPPING
) -> STATE_DICT_TYPE:
    """Restore the original nested state_dict according to ``mapping`` and the flattened ``state_dict``."""
    nested: STATE_DICT_TYPE = {}
    for key, value in state_dict.items():
        _set_element(nested, mapping[key], value)
    return nested
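
# Hedged round-trip sketch for the flatten/unflatten helpers (illustrative
# values):
#
#   nested = {"a": {"b": torch.zeros(2)}, "c": [1, 2]}
#   flat, mapping = _flatten_state_dict(nested)
#   # flat    == {"a.b": tensor([0., 0.]), "c.0": 1, "c.1": 2}
#   # mapping == {"a.b": ("a", "b"), "c.0": ("c", 0), "c.1": ("c", 1)}
#   restored = _unflatten_state_dict(flat, mapping)  # structurally equal to ``nested``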