#pragma once

#include <torch/csrc/Export.h>
#include <torch/types.h>

namespace torch {
namespace nn {
namespace utils {

// This helper function checks that all the given parameters are located
// on the same device. Currently, the conversion between model parameters
// and a single vector form is not supported for multiple allocations,
// e.g. parameters on different GPUs, or a mixture of CPU/GPU.
inline std::optional<int64_t> _check_param_device(
    const torch::Tensor& param,
    std::optional<int64_t> old_param_device) {
  // This is the first parameter encountered
  if (old_param_device == std::nullopt) {
    old_param_device = param.is_cuda() ? param.get_device() : -1;
  } else {
    bool warn = false;
    if (param.is_cuda()) { // Check if on the same GPU
      warn = (param.get_device() != old_param_device.value());
    } else { // Check if on the CPU
      warn = (old_param_device.value() != -1);
    }
    if (warn) {
      TORCH_CHECK(
          false,
          "Found two parameters on different devices, ",
          "this is currently not supported.");
    }
  }

  return old_param_device;
}
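
// The helper is intended to be threaded through a loop as an accumulator.
// A minimal sketch (`t1` and `t2` are hypothetical tensors):
//
//   std::optional<int64_t> dev;          // empty: no parameter seen yet
//   dev = _check_param_device(t1, dev);  // records t1's device
//   dev = _check_param_device(t2, dev);  // throws if t2 is on another device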

// Convert parameters to one vector
inline torch::Tensor parameters_to_vector(
    const std::vector<torch::Tensor>& parameters) {
  std::optional<int64_t> param_device;

  std::vector<torch::Tensor> vec;
  vec.reserve(parameters.size());

  for (const torch::Tensor& param : parameters) {
    // Ensure the parameters are located on the same device
    param_device = _check_param_device(param, param_device);

    // Flatten each parameter to 1-D before concatenation
    vec.push_back(param.view(-1));
  }

  return torch::cat(vec);
}
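
// Example usage (a minimal sketch; the `model` below is only illustrative):
//
//   torch::nn::Linear model(4, 2);
//   torch::Tensor flat =
//       torch::nn::utils::parameters_to_vector(model->parameters());
//   // `flat` is a 1-D tensor containing every parameter's values,
//   // concatenated in iteration order (here: weight, then bias).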

// Convert one vector to the parameters
inline void vector_to_parameters(
    const torch::Tensor& vec,
    const std::vector<torch::Tensor>& parameters) {
  // Flag for the device where the parameters are located
  std::optional<int64_t> param_device;

  // Pointer for slicing the vector for each parameter
  int64_t pointer = 0;
  for (const torch::Tensor& param : parameters) {
    // Ensure the parameters are located on the same device
    param_device = _check_param_device(param, param_device);

    // The length of the parameter
    auto num_param = param.numel();
    // Slice the vector, reshape it, and replace the old data of the parameter
    param.set_data(
        vec.slice(0, pointer, pointer + num_param).view_as(param).data());

    // Increment the pointer
    pointer += num_param;
  }
}
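
// Example round trip (a minimal sketch; reuses the illustrative `model`
// from above):
//
//   torch::Tensor flat = parameters_to_vector(model->parameters());
//   vector_to_parameters(flat * 0.5, model->parameters());
//   // Each parameter now holds half of its previous values.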

} // namespace utils
} // namespace nn
} // namespace torch