from typing import Iterable, Optional

import torch


def parameters_to_vector(parameters: Iterable[torch.Tensor]) -> torch.Tensor:
    r"""Flatten an iterable of parameters into a single vector.

    Args:
        parameters (Iterable[Tensor]): an iterable of Tensors that are the
            parameters of a model.

    Returns:
        The parameters represented by a single vector
    """
    # Flag for the device where the parameters are located
    param_device = None

    vec = []
    for param in parameters:
        # Ensure the parameters are located on the same device
        param_device = _check_param_device(param, param_device)

        vec.append(param.view(-1))
    return torch.cat(vec)


def vector_to_parameters(
    vec: torch.Tensor, parameters: Iterable[torch.Tensor]
) -> None:
    r"""Copy slices of a vector into an iterable of parameters.

    Args:
        vec (Tensor): a single vector representing the parameters of a model.
        parameters (Iterable[Tensor]): an iterable of Tensors that are the
            parameters of a model.
    """
    # Ensure vec is a Tensor
    if not isinstance(vec, torch.Tensor):
        raise TypeError(f"expected torch.Tensor, but got: {torch.typename(vec)}")
    # Flag for the device where the parameters are located
    param_device = None

    # Pointer for slicing the vector for each parameter
    pointer = 0
    for param in parameters:
        # Ensure the parameters are located on the same device
        param_device = _check_param_device(param, param_device)

        # The number of elements in the parameter
        num_param = param.numel()
        # Slice the vector, reshape it, and replace the old data of the parameter
        param.data = vec[pointer : pointer + num_param].view_as(param).data

        # Advance the pointer past this parameter's elements
        pointer += num_param


def _check_param_device(param: torch.Tensor, old_param_device: Optional[int]) -> int:
    r"""Check that the parameters are located on the same device.

    Currently, converting between model parameters and a single vector is not
    supported across multiple allocations, e.g. parameters on different
    GPUs/PrivateUse1 devices, or a mixture of CPU and GPU/PrivateUse1.

    Args:
        param (Tensor): a Tensor that is a parameter of a model.
        old_param_device (int): the device index where the first parameter of a
            model is allocated, or -1 for CPU.

    Returns:
        old_param_device (int): the device index shared by the parameters seen
            so far (-1 for CPU).
    """
    # First parameter encountered: record its device
    support_device_types = ["cuda", torch._C._get_privateuse1_backend_name()]
    if old_param_device is None:
        old_param_device = (
            param.get_device() if param.device.type in support_device_types else -1
        )
    else:
        warn = False
        if (
            param.device.type in support_device_types
        ):  # Check if on the same GPU/PrivateUse1 device
            warn = param.get_device() != old_param_device
        else:  # Check if on CPU
            warn = old_param_device != -1
        if warn:
            raise TypeError(
                "Found two parameters on different devices, "
                "this is currently not supported."
            )
    return old_param_device
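

# ---------------------------------------------------------------------------
# A minimal usage sketch (illustrative only, not part of this module's API):
# round-trip a toy model's parameters through a flat vector. The nn.Linear
# model and the 0.01 perturbation below are assumptions made for the example.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Linear(4, 2)  # hypothetical toy model: 2*4 weights + 2 biases

    # Flatten all parameters into one 1-D tensor (10 elements here).
    flat = parameters_to_vector(model.parameters())
    assert flat.numel() == sum(p.numel() for p in model.parameters())

    # Perturb the flat vector and copy it back into the parameters in place;
    # _check_param_device runs inside both calls, so mixed devices would raise.
    vector_to_parameters(flat + 0.01, model.parameters())
    assert torch.allclose(parameters_to_vector(model.parameters()), flat + 0.01)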