• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #pragma once
2 
3 #include <torch/csrc/Export.h>
4 #include <torch/types.h>
5 
6 namespace torch {
7 namespace nn {
8 namespace utils {
9 
10 // This helper function is to check if the parameters are located
11 // in the same device. Currently, the conversion between model parameters
12 // and single vector form is not supported for multiple allocations,
13 // e.g. parameters in different GPUs, or mixture of CPU/GPU.
_check_param_device(const torch::Tensor & param,std::optional<int64_t> old_param_device)14 inline std::optional<int64_t> _check_param_device(
15     const torch::Tensor& param,
16     std::optional<int64_t> old_param_device) {
17   // Meet the first parameter
18   if (old_param_device == std::nullopt) {
19     old_param_device = param.is_cuda() ? param.get_device() : -1;
20   } else {
21     bool warn = false;
22     if (param.is_cuda()) { // Check if in same GPU
23       warn = (param.get_device() != old_param_device.value());
24     } else { // Check if in CPU
25       warn = (old_param_device.value() != -1);
26     }
27     if (warn) {
28       TORCH_CHECK(
29           false,
30           "Found two parameters on different devices, ",
31           "this is currently not supported.");
32     }
33   }
34 
35   return old_param_device;
36 }
37 
38 // Convert parameters to one vector
parameters_to_vector(const std::vector<torch::Tensor> & parameters)39 inline torch::Tensor parameters_to_vector(
40     const std::vector<torch::Tensor>& parameters) {
41   std::optional<int64_t> param_device;
42 
43   std::vector<torch::Tensor> vec;
44   vec.reserve(parameters.size());
45 
46   for (const torch::Tensor& param : parameters) {
47     // Ensure the parameters are located in the same device
48     param_device = _check_param_device(param, param_device);
49 
50     vec.push_back(param.view(-1));
51   }
52 
53   return torch::cat(vec);
54 }
55 
56 // Convert one vector to the parameters
vector_to_parameters(const torch::Tensor & vec,const std::vector<torch::Tensor> & parameters)57 inline void vector_to_parameters(
58     const torch::Tensor& vec,
59     const std::vector<torch::Tensor>& parameters) {
60   // Flag for the device where the parameter is located
61   std::optional<int64_t> param_device;
62 
63   // Pointer for slicing the vector for each parameter
64   int64_t pointer = 0;
65   for (const torch::Tensor& param : parameters) {
66     // Ensure the parameters are located in the same device
67     param_device = _check_param_device(param, param_device);
68 
69     // The length of the parameter
70     auto num_param = param.numel();
71     // Slice the vector, reshape it, and replace the old data of the parameter
72     param.set_data(
73         vec.slice(0, pointer, pointer + num_param).view_as(param).data());
74 
75     // Increment the pointer
76     pointer += num_param;
77   }
78 }
79 
80 } // namespace utils
81 } // namespace nn
82 } // namespace torch
83