from typing import Iterable, Optional

import torch


def parameters_to_vector(parameters: Iterable[torch.Tensor]) -> torch.Tensor:
    r"""Flatten an iterable of parameters into a single vector.

    Args:
        parameters (Iterable[Tensor]): an iterable of Tensors that are the
            parameters of a model.

    Returns:
        The parameters represented by a single vector.
    """
    # Flag for the device where the parameters are located
    param_device = None

    vec = []
    for param in parameters:
        # Ensure all parameters are located on the same device
        param_device = _check_param_device(param, param_device)

        vec.append(param.view(-1))
    return torch.cat(vec)
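

# A minimal usage sketch (an editor's illustration, not part of the original
# module): flattening the parameters of a small hypothetical model. The
# two-layer torch.nn.Sequential model is an assumption; any iterable of
# parameter tensors works the same way.
def _example_parameters_to_vector() -> None:
    model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))
    vec = parameters_to_vector(model.parameters())
    # The flat vector holds every parameter element, in iteration order.
    assert vec.numel() == sum(p.numel() for p in model.parameters())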


def vector_to_parameters(vec: torch.Tensor, parameters: Iterable[torch.Tensor]) -> None:
    r"""Copy slices of a vector into an iterable of parameters.

    Args:
        vec (Tensor): a single vector representing the parameters of a model.
        parameters (Iterable[Tensor]): an iterable of Tensors that are the
            parameters of a model.
    """
    # Ensure vec is of type Tensor
    if not isinstance(vec, torch.Tensor):
        raise TypeError(f"expected torch.Tensor, but got: {torch.typename(vec)}")
    # Flag for the device where the parameters are located
    param_device = None

    # Pointer for slicing the vector for each parameter
    pointer = 0
    for param in parameters:
        # Ensure all parameters are located on the same device
        param_device = _check_param_device(param, param_device)

        # The number of elements in the parameter
        num_param = param.numel()
        # Slice the vector, reshape it, and replace the old data of the parameter
        param.data = vec[pointer : pointer + num_param].view_as(param).data

        # Increment the pointer by the number of elements consumed
        pointer += num_param
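

# A minimal usage sketch (an editor's illustration, not part of the original
# module): writing a perturbed flat vector back into a model's parameters.
# The torch.nn.Linear model and the 0.1 offset are assumptions for the demo.
def _example_vector_to_parameters() -> None:
    model = torch.nn.Linear(3, 2)
    vec = parameters_to_vector(model.parameters())
    # Shift every element and copy the result back into the parameters;
    # each parameter now equals its original value plus 0.1.
    vector_to_parameters(vec + 0.1, model.parameters())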


def _check_param_device(param: torch.Tensor, old_param_device: Optional[int]) -> int:
    r"""Check that a parameter is located on the same device as its predecessors.

    Currently, converting between model parameters and a single vector is not
    supported when the parameters span multiple allocations, e.g. parameters on
    different GPU/PrivateUse1 devices, or a mixture of CPU and GPU/PrivateUse1.

    Args:
        param (Tensor): a Tensor of a parameter of a model.
        old_param_device (int): the device index where the first parameter of a
                                model is allocated (-1 for CPU).

    Returns:
        old_param_device (int): the device index shared by the parameters seen
        so far (-1 for CPU).
    """
    support_device_types = ["cuda", torch._C._get_privateuse1_backend_name()]
    # First parameter encountered: record its device index (-1 for CPU)
    if old_param_device is None:
        old_param_device = (
            param.get_device() if param.device.type in support_device_types else -1
        )
    else:
        warn = False
        if (
            param.device.type in support_device_types
        ):  # Check if on the same GPU/PrivateUse1 device
            warn = param.get_device() != old_param_device
        else:  # Check if on CPU
            warn = old_param_device != -1
        if warn:
            raise TypeError(
                "Found two parameters on different devices, "
                "this is currently not supported."
            )
    return old_param_device
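

# A minimal sketch (an editor's illustration, not part of the original module):
# _check_param_device records -1 for a CPU parameter and raises once a
# parameter on another device is mixed in. The GPU branch is an assumption,
# guarded by a CUDA availability check.
def _example_check_param_device() -> None:
    cpu_param = torch.zeros(2)
    device = _check_param_device(cpu_param, None)  # -> -1 for CPU
    if torch.cuda.is_available():
        gpu_param = torch.zeros(2, device="cuda")
        try:
            _check_param_device(gpu_param, device)  # CPU followed by GPU
        except TypeError:
            pass  # mixing devices is rejected, as documented above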