#pragma once

#include <torch/csrc/Export.h>
#include <torch/types.h>

#include <cmath>
#include <initializer_list>
#include <limits>
#include <optional>
#include <utility>
#include <vector>

namespace torch {
namespace nn {
namespace utils {

// Clips the gradient norm of a vector of Tensors.
// See
// https://pytorch.org/docs/stable/nn.html?highlight=clip_grad_norm#torch.nn.utils.clip_grad_norm_
// for more details about this module.
//
// Difference from the Python version: even when the finiteness check is
// skipped (error_if_nonfinite = false), this function introduces a
// device <=> CPU synchronization (for devices where that makes sense!) in
// order to return a CPU-side `double`. This C++ version therefore cannot be
// run fully asynchronously w.r.t. the device of the gradients.
inline double clip_grad_norm_(
    const std::vector<Tensor>& parameters,
    double max_norm,
    double norm_type = 2.0,
    bool error_if_nonfinite = false) {
  std::vector<Tensor> params_with_grad;

  for (const auto& param : parameters) {
    auto& grad = param.grad();
    if (grad.defined()) {
      params_with_grad.push_back(param);
    }
  }

  if (params_with_grad.empty()) {
    return 0.0;
  }

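  // Compute the total gradient norm. The inf-norm is the largest absolute
  // entry across all gradients, norm_type == 0 simply counts the parameters
  // that have gradients, and any other norm_type combines the per-parameter
  // norms into a single p-norm over the whole list.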
  Tensor total_norm_tensor;
  if (norm_type == std::numeric_limits<double>::infinity()) {
    std::vector<Tensor> norms;
    norms.reserve(params_with_grad.size());

    for (const auto& param : params_with_grad) {
      norms.emplace_back(param.grad().data().abs().max());
    }
    total_norm_tensor =
        (norms.size() == 1) ? norms[0] : torch::max(torch::stack(norms));
  } else if (norm_type == 0) {
    total_norm_tensor =
        torch::full({}, static_cast<double>(params_with_grad.size()));
  } else {
    std::vector<Tensor> norms;
    norms.reserve(params_with_grad.size());

    for (const auto& param : params_with_grad) {
      norms.emplace_back(param.grad().data().norm(norm_type));
    }
    total_norm_tensor =
        (norms.size() == 1) ? norms[0] : torch::stack(norms).norm(norm_type);
  }

  // When possible (i.e. when skipping the finiteness check), we avoid
  // synchronizing the CPU and the gradients' device until the very end to
  // preserve async execution on the device. When checking for finiteness,
  // this optional ensures we only sync once.
  std::optional<double> total_norm = std::nullopt;
  if (error_if_nonfinite) {
    total_norm = total_norm_tensor.item().toDouble();
    TORCH_CHECK(
        std::isfinite(*total_norm),
        "The total norm of order ",
        norm_type,
        " for gradients from `parameters` ",
        "is non-finite, so it cannot be clipped. To disable this error and scale ",
        "the gradients with the non-finite norm anyway, set ",
        "`error_if_nonfinite=false`");
  }

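  // Every gradient is scaled by max_norm / (total_norm + eps), clamped to at
  // most 1 so gradients are only ever scaled down; the epsilon guards against
  // division by zero when the total norm is zero.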
  auto clip_coef = max_norm / (total_norm_tensor + 1e-6);
  auto clip_coef_clamped =
      torch::clamp(clip_coef, std::nullopt /* min */, 1.0 /* max */);
  for (auto& param : params_with_grad) {
    param.grad().data().mul_(clip_coef_clamped);
  }

  if (!total_norm.has_value()) {
    total_norm = total_norm_tensor.item().toDouble();
  }
  return *total_norm;
}

// A wrapper around clip_grad_norm_ that allows us to call the function with a
// braced-init-list of Tensors.
inline double clip_grad_norm_(
    std::initializer_list<Tensor> parameters,
    double max_norm,
    double norm_type = 2.0,
    bool error_if_nonfinite = false) {
  return clip_grad_norm_(
      std::vector<Tensor>(parameters), max_norm, norm_type, error_if_nonfinite);
}

// A wrapper around clip_grad_norm_ that allows us to call the function with a
// single Tensor.
inline double clip_grad_norm_(
    Tensor parameter,
    double max_norm,
    double norm_type = 2.0,
    bool error_if_nonfinite = false) {
  std::vector<Tensor> params = {std::move(parameter)};
  return clip_grad_norm_(
      std::move(params), max_norm, norm_type, error_if_nonfinite);
}
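
// Illustrative usage sketch (not part of the original header): `Net`, `input`,
// `target`, and `optimizer` below are hypothetical user-side names.
//
//   auto model = std::make_shared<Net>();
//   torch::optim::SGD optimizer(model->parameters(), /*lr=*/0.01);
//
//   optimizer.zero_grad();
//   auto loss = torch::mse_loss(model->forward(input), target);
//   loss.backward();
//   // Rescale all gradients in place so their combined 2-norm is at most 1.0,
//   // and read back the pre-clipping norm as a CPU-side double.
//   double total_norm = torch::nn::utils::clip_grad_norm_(
//       model->parameters(), /*max_norm=*/1.0);
//   optimizer.step();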

// Clips the gradient of an iterable of parameters at the specified value.
// Gradients are modified in-place.
// See https://pytorch.org/docs/stable/nn.html#clip-grad-value
// for more details about this module.
inline void clip_grad_value_(
    const std::vector<Tensor>& parameters,
    double clip_value) {
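  // Note: the clamp is applied through grad().data(), i.e. on an alias of the
  // gradient that is detached from the autograd graph, so the in-place
  // modification is not recorded as a differentiable operation.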
  for (const auto& param : parameters) {
    if (param.grad().defined()) {
      param.grad().data().clamp_(-clip_value, clip_value);
    }
  }
}

// A wrapper around clip_grad_value_ that allows us to call the function with a
// braced-init-list of Tensors.
inline void clip_grad_value_(
    std::initializer_list<Tensor> parameters,
    double clip_value) {
  clip_grad_value_(std::vector<Tensor>(parameters), clip_value);
}

// A wrapper around clip_grad_value_ that allows us to call the function with a
// single Tensor.
inline void clip_grad_value_(Tensor parameter, double clip_value) {
  std::vector<Tensor> params = {std::move(parameter)};
  clip_grad_value_(std::move(params), clip_value);
}
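
// Illustrative usage sketch (not part of the original header): `model` and
// `optimizer` are hypothetical objects held by the caller.
//
//   loss.backward();
//   // Clamp every gradient element into [-0.5, 0.5] before the optimizer step.
//   torch::nn::utils::clip_grad_value_(model->parameters(), /*clip_value=*/0.5);
//   optimizer.step();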

} // namespace utils
} // namespace nn
} // namespace torch