#include <torch/optim/lbfgs.h>

#include <torch/optim/optimizer.h>
#include <torch/optim/serialize.h>
#include <torch/serialize/archive.h>
#include <torch/utils.h>

#include <ATen/ATen.h>
#include <c10/util/irange.h>

#include <algorithm>
#include <cmath>
#include <functional>
#include <memory>
#include <vector>

namespace torch {
namespace optim {

LBFGSOptions::LBFGSOptions(double lr) : lr_(lr) {}

bool operator==(const LBFGSOptions& lhs, const LBFGSOptions& rhs) {
  return (lhs.lr() == rhs.lr()) && (lhs.max_iter() == rhs.max_iter()) &&
      (lhs.max_eval() == rhs.max_eval()) &&
      (lhs.tolerance_grad() == rhs.tolerance_grad()) &&
      (lhs.tolerance_change() == rhs.tolerance_change()) &&
      (lhs.history_size() == rhs.history_size()) &&
      (lhs.line_search_fn() == rhs.line_search_fn());
}

void LBFGSOptions::serialize(torch::serialize::OutputArchive& archive) const {
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(lr);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(max_iter);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(max_eval);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(tolerance_grad);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(tolerance_change);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(history_size);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(line_search_fn);
}

void LBFGSOptions::serialize(torch::serialize::InputArchive& archive) {
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, lr);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(int64_t, max_iter);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_OPTIONAL(int64_t, max_eval);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, tolerance_grad);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, tolerance_change);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(int64_t, history_size);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_OPTIONAL(std::string, line_search_fn);
}

double LBFGSOptions::get_lr() const {
  return lr();
}

void LBFGSOptions::set_lr(const double lr) {
  this->lr(lr);
}

// Returns true if two tensor containers have the same size and
// element-wise equal tensors.
template <typename T>
bool if_container_equal(T lhs, T rhs) {
  if (!(lhs.size() == rhs.size()))
    return false;
  for (const auto i : c10::irange(lhs.size())) {
    if (!torch::equal(lhs.at(i), rhs.at(i)))
      return false;
  }
  return true;
}

bool operator==(const LBFGSParamState& lhs, const LBFGSParamState& rhs) {
  auto isNull = [](const std::optional<std::vector<Tensor>>& val) {
    return val == std::nullopt;
  };
  return (lhs.func_evals() == rhs.func_evals()) &&
      (lhs.n_iter() == rhs.n_iter()) && (lhs.t() == rhs.t()) &&
      (lhs.prev_loss() == rhs.prev_loss()) &&
      torch::equal_if_defined(lhs.d(), rhs.d()) &&
      torch::equal_if_defined(lhs.H_diag(), rhs.H_diag()) &&
      torch::equal_if_defined(lhs.prev_flat_grad(), rhs.prev_flat_grad()) &&
      if_container_equal(lhs.old_dirs(), rhs.old_dirs()) &&
      if_container_equal(lhs.old_stps(), rhs.old_stps()) &&
      if_container_equal(lhs.ro(), rhs.ro()) &&
      ((isNull(lhs.al()) && isNull(rhs.al())) ||
       (!isNull(lhs.al()) && !isNull(rhs.al()) &&
        if_container_equal(*lhs.al(), *rhs.al())));
}

void LBFGSParamState::serialize(
    torch::serialize::OutputArchive& archive) const {
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(func_evals);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(n_iter);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(t);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(prev_loss);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(d);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(H_diag);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG(prev_flat_grad);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG_DEQUE(old_dirs);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG_DEQUE(old_stps);
  _TORCH_OPTIM_SERIALIZE_TORCH_ARG_DEQUE(ro);
  // Python version only serializes state vars if explicitly defined
  if (al() != std::nullopt) {
    _TORCH_OPTIM_SERIALIZE_TORCH_ARG(al);
  }
}

void LBFGSParamState::serialize(torch::serialize::InputArchive& archive) {
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(int64_t, func_evals);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(int64_t, n_iter);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, t);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, prev_loss);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, d);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, H_diag);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(Tensor, prev_flat_grad);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_DEQUE(std::deque<Tensor>, old_dirs);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_DEQUE(std::deque<Tensor>, old_stps);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_DEQUE(std::deque<Tensor>, ro);
  _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_OPTIONAL(std::vector<Tensor>, al);
}

// Flattens the gradients of all parameters in the (single) param group into
// one 1-D tensor; undefined gradients contribute zeros, sparse gradients are
// densified.
Tensor LBFGS::_gather_flat_grad() {
  std::vector<Tensor> views;
  for (const auto& p : param_groups_.at(0).params()) {
    if (!p.grad().defined()) {
      views.emplace_back(p.new_empty({p.numel()}).zero_());
    } else if (p.grad().is_sparse()) {
      views.emplace_back(p.grad().to_dense().view(-1));
    } else {
      views.emplace_back(p.grad().view(-1));
    }
  }
  return torch::cat(views, 0);
}

// Total number of elements across all parameters, cached after the first call.
int64_t LBFGS::_numel() {
  if (_numel_cache == std::nullopt) {
    auto res = 0;
    for (const auto& p : param_groups_.at(0).params()) {
      res += p.numel();
    }
    _numel_cache = res;
  }
  return *_numel_cache;
}

// In-place update: p += step_size * update (flat update sliced per parameter).
void LBFGS::_add_grad(const double step_size, const Tensor& update) {
  auto offset = 0;
  for (auto& p : param_groups_.at(0).params()) {
    auto numel = p.numel();
    // view as to avoid deprecated pointwise semantics
    p.add_(
        update.index({at::indexing::Slice(offset, offset + numel)}).view_as(p),
        step_size);
    offset += numel;
  }
  TORCH_INTERNAL_ASSERT(offset == _numel());
}

void LBFGS::_set_param(const std::vector<Tensor>& params_data) {
  auto& _params = param_groups_.at(0).params();
  TORCH_INTERNAL_ASSERT(params_data.size() == _params.size());
  for (const auto i : c10::irange(_params.size())) {
    _params.at(i).copy_(params_data.at(i));
  }
}

std::vector<Tensor> LBFGS::_clone_param() {
  std::vector<Tensor> result;
  for (const auto& p : param_groups_.at(0).params()) {
    result.emplace_back(p.clone(at::MemoryFormat::Contiguous));
  }
  return result;
}

// Evaluates the closure at x + t * d, then restores the original parameters.
std::tuple<double, Tensor> LBFGS::_directional_evaluate(
    const LossClosure& closure,
    const std::vector<Tensor>& x,
    double t,
    const Tensor& d) {
  _add_grad(t, d);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  double loss;
  {
    torch::AutoGradMode enable_grad(true);
    loss = closure().item<double>();
  }
  auto flat_grad = _gather_flat_grad();
  _set_param(x);
  return std::make_tuple(loss, flat_grad);
}

static double _cubic_interpolate(
    double x1,
    double f1,
    double g1,
    double x2,
    double f2,
    double g2,
    std::optional<std::tuple<double, double>> bounds = std::nullopt) {
  // ported from https://github.com/torch/optim/blob/master/polyinterp.lua
  // Compute bounds of interpolation area
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  double xmin_bound, xmax_bound;
  if (bounds != std::nullopt) {
    std::tie(xmin_bound, xmax_bound) = *bounds;
  } else {
    std::tie(xmin_bound, xmax_bound) =
        (x1 <= x2) ? std::make_tuple(x1, x2) : std::make_tuple(x2, x1);
  }
  // Code for most common case: cubic interpolation of 2 points
  //   w/ function and derivative values for both
  // Solution in this case (where x2 is the farthest point):
  //   d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
  //   d2 = sqrt(d1^2 - g1*g2);
  //   min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
  //   t_new = min(max(min_pos,xmin_bound),xmax_bound);
  auto d1 = (g1 + g2) - (3 * (f1 - f2) / (x1 - x2));
  auto d2_square = std::pow(d1, 2) - g1 * g2;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  double d2;
  if (d2_square >= 0) {
    d2 = std::sqrt(d2_square);
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    double min_pos;
    if (x1 <= x2) {
      min_pos = x2 - ((x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2)));
    } else {
      min_pos = x1 - ((x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2)));
    }
    return std::min(std::max(min_pos, xmin_bound), xmax_bound);
  } else {
    return (xmin_bound + xmax_bound) / 2;
  }
}

using Function = std::function<std::tuple<double, Tensor>(
    const std::vector<Tensor>& x,
    double t,
    const Tensor& d)>;

// Strong Wolfe line search; returns (f_new, g_new, t, ls_func_evals).
static std::tuple<double, Tensor, double, int64_t> _strong_wolfe(
    const Function& obj_func,
    const std::vector<Tensor>& x,
    double t,
    const Tensor& d,
    double f,
    Tensor g,
    const Tensor& gtd,
    double c1 = 1e-4,
    double c2 = 0.9, // NOLINT(cppcoreguidelines-avoid-magic-numbers)
    double tolerance_change = 1e-9,
    double max_ls = 25) { // NOLINT(cppcoreguidelines-avoid-magic-numbers)
  auto val = [](const Tensor& t) { return t.item<double>(); };

  auto d_norm = val(d.abs().max());
  g = g.clone(at::MemoryFormat::Contiguous);
  // evaluate objective and gradient using initial step
  auto [f_new, g_new] = obj_func(x, t, d);
  int64_t ls_func_evals = 1;
  auto gtd_new = g_new.dot(d);

  // bracket an interval containing a point satisfying the Wolfe criteria
  double t_prev = 0;
  auto f_prev = f;
  auto g_prev = g;
  auto gtd_prev = gtd;
  bool done = false;
  auto ls_iter = 0;
  std::vector<double> bracket, bracket_f;
  std::vector<Tensor> bracket_g, bracket_gtd;

  while (ls_iter < max_ls) {
    // check conditions
    if ((f_new > (f + c1 * t * val(gtd))) ||
        (ls_iter > 1 && (f_new >= f_prev))) {
      bracket = {t_prev, t};
      bracket_f = {f_prev, f_new};
      bracket_g = {g_prev, g_new.clone(at::MemoryFormat::Contiguous)};
      bracket_gtd = {gtd_prev, gtd_new};
      break;
    }
    if (std::abs(val(gtd_new)) <= (-c2 * val(gtd))) {
      bracket = {t, t};
      bracket_f = {f_new, f_new};
      bracket_g = {g_new, g_new};
      done = true;
      break;
    }
    if (val(gtd_new) >= 0) {
      bracket = {t_prev, t};
      bracket_f = {f_prev, f_new};
      bracket_g = {g_prev, g_new.clone(at::MemoryFormat::Contiguous)};
      bracket_gtd = {gtd_prev, gtd_new};
      break;
    }
    // interpolate
    auto min_step = t + 0.01 * (t - t_prev); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
    auto max_step = t * 10; // NOLINT(cppcoreguidelines-avoid-magic-numbers)
    auto tmp = t;
    t = _cubic_interpolate(
        t_prev,
        f_prev,
        val(gtd_prev),
        t,
        f_new,
        val(gtd_new),
        std::make_tuple(min_step, max_step));
    // next step
    t_prev = tmp;
    f_prev = f_new;
    g_prev = g_new.clone(at::MemoryFormat::Contiguous);
    gtd_prev = gtd_new;
    std::tie(f_new, g_new) = obj_func(x, t, d);
    ls_func_evals += 1;
    gtd_new = g_new.dot(d);
    ls_iter += 1;
  }
  // reached max number of iterations?
  if (ls_iter == max_ls) {
    bracket = {0, t};
    bracket_f = {f, f_new};
    bracket_g = {g, g_new};
  }

  // zoom phase: we now have a point satisfying the criteria, or
  // a bracket around it. We refine the bracket until we find the
  // exact point satisfying the criteria
  bool insuf_progress = false;
  // find high and low points in bracket
  auto [low_pos, high_pos] = bracket_f[0] <= bracket_f[1]
      ? std::make_tuple(0, 1)
      : std::make_tuple(1, 0);
  while (!done && (ls_iter < max_ls)) {
    // compute new trial value
    t = _cubic_interpolate(
        bracket[0],
        bracket_f[0],
        val(bracket_gtd[0]),
        bracket[1],
        bracket_f[1],
        val(bracket_gtd[1]));

    // test that we are making sufficient progress:
    // in case `t` is so close to boundary, we mark that we are making
    // insufficient progress, and if
    //   + we have made insufficient progress in the last step, or
    //   + `t` is at one of the boundary,
    // we will move `t` to a position which is `0.1 * len(bracket)`
    // away from the nearest boundary point.
    double bracket_max = std::max(bracket[0], bracket[1]);
    auto bracket_min = std::min(bracket[0], bracket[1]);
    auto eps = 0.1 * (bracket_max - bracket_min); // NOLINT(cppcoreguidelines-avoid-magic-numbers)
    if (std::min(bracket_max - t, t - bracket_min) < eps) {
      // interpolation close to boundary
      if (insuf_progress || (t >= bracket_max) || (t <= bracket_min)) {
        // evaluate at 0.1 away from boundary
        t = (std::abs(t - bracket_max) < std::abs(t - bracket_min))
            ? bracket_max - eps
            : bracket_min + eps;
        insuf_progress = false;
      } else {
        insuf_progress = true;
      }
    } else {
      insuf_progress = false;
    }

    // Evaluate new point
    std::tie(f_new, g_new) = obj_func(x, t, d);
    ls_func_evals += 1;
    gtd_new = g_new.dot(d);
    ls_iter += 1;

    if ((f_new > (f + c1 * t * val(gtd))) || (f_new >= bracket_f[low_pos])) {
      // Armijo condition not satisfied or not lower than lowest point
      bracket[high_pos] = t;
      bracket_f[high_pos] = f_new;
      bracket_g[high_pos] = g_new.clone(at::MemoryFormat::Contiguous);
      bracket_gtd[high_pos] = gtd_new;
      std::tie(low_pos, high_pos) = bracket_f[0] <= bracket_f[1]
          ? std::make_tuple(0, 1)
          : std::make_tuple(1, 0);
    } else {
      if (val(at::abs(gtd_new)) <= (-c2 * val(gtd))) {
        // Wolfe conditions satisfied
        done = true;
      } else if ((val(gtd_new) * (bracket[high_pos] - bracket[low_pos])) >= 0) {
        // old high becomes new low
        bracket[high_pos] = bracket[low_pos];
        bracket_f[high_pos] = bracket_f[low_pos];
        bracket_g[high_pos] = bracket_g[low_pos];
        bracket_gtd[high_pos] = bracket_gtd[low_pos];
      }

      // new point becomes new low
      bracket[low_pos] = t;
      bracket_f[low_pos] = f_new;
      bracket_g[low_pos] = g_new.clone(at::MemoryFormat::Contiguous);
      bracket_gtd[low_pos] = gtd_new;
    }

    // line-search bracket is so small
    if ((std::abs(bracket[1] - bracket[0]) * d_norm) < tolerance_change)
      break;
  }

  // return stuff
  t = bracket[low_pos];
  f_new = bracket_f[low_pos];
  g_new = bracket_g[low_pos];
  return std::make_tuple(f_new, g_new, t, ls_func_evals);
}

Tensor LBFGS::step(LossClosure closure) {
  NoGradGuard no_grad;
  TORCH_CHECK(closure != nullptr, "LBFGS requires a closure function");
  TORCH_INTERNAL_ASSERT(param_groups_.size() == 1);
  auto val = [](const Tensor& t) { return t.item<double>(); };

  auto& group = param_groups_.at(0);
  auto& _params = group.params();
  const auto& options = static_cast<const LBFGSOptions&>(group.options());
  auto lr = options.lr();
  auto max_iter = options.max_iter();
  auto max_eval = options.max_eval();
  auto tolerance_grad = options.tolerance_grad();
  auto tolerance_change = options.tolerance_change();
  auto line_search_fn = options.line_search_fn();
  auto history_size = options.history_size();

  // NOTE: LBFGS has only global state, but we register it as state for
  // the first param, because this helps with casting in load_state_dict
  auto param_state = state_.find(_params.at(0).unsafeGetTensorImpl());
  if (param_state == state_.end()) {
    state_[_params.at(0).unsafeGetTensorImpl()] =
        std::make_unique<LBFGSParamState>();
  }
  auto& state = static_cast<LBFGSParamState&>(
      *state_[_params.at(0).unsafeGetTensorImpl()]);
  // evaluate initial f(x) and df/dx
  Tensor orig_loss;
  {
    torch::AutoGradMode enable_grad(true);
    orig_loss = closure();
  }
  auto loss = val(orig_loss);
  auto current_evals = 1;
  state.func_evals(state.func_evals() + 1);
  auto flat_grad = _gather_flat_grad();
  auto opt_cond = (val(flat_grad.abs().max()) <= tolerance_grad);

  // optimal condition
  if (opt_cond) {
    return orig_loss;
  }

  // tensors cached in state (for tracing)
  auto& d = state.d();
  auto& t = state.t();
  auto& old_dirs = state.old_dirs();
  auto& old_stps = state.old_stps();
  auto& ro = state.ro();
  auto& H_diag = state.H_diag();
  auto& prev_flat_grad = state.prev_flat_grad();
  auto& prev_loss = state.prev_loss();

  int n_iter = 0;
  // optimize for a max of max_iter iterations
  while (n_iter < max_iter) {
    // keep track of nb of iterations
    n_iter += 1;
    state.n_iter(state.n_iter() + 1);

    // compute gradient descent direction
    if (state.n_iter() == 1) {
      d = flat_grad.neg();
      H_diag = torch::tensor(1);
      old_dirs = {};
      old_stps = {};
      ro = {};
    } else {
      // do lbfgs update (update memory)
      auto y = flat_grad.sub(prev_flat_grad);
      auto s = d.mul(t);
      auto ys = y.dot(s); // y*s
      if (val(ys) > 1e-10) { // NOLINT(cppcoreguidelines-avoid-magic-numbers)
        // updating memory
        if (static_cast<int64_t>(old_dirs.size()) == history_size) {
          // shift history by one (limited-memory)
          old_dirs.pop_front();
          old_stps.pop_front();
          ro.pop_front();
        }
        // store new direction/step
        old_dirs.emplace_back(y);
        old_stps.emplace_back(s);
        ro.emplace_back(1. / ys);

        // update scale of initial Hessian approximation
        H_diag = ys / y.dot(y); // (y*y)
      }

      // compute the approximate (L-BFGS) inverse Hessian
      // multiplied by the gradient
      int64_t num_old = static_cast<int64_t>(old_dirs.size());

      if (state.al() == std::nullopt) {
        state.al(std::vector<Tensor>(history_size));
      }
      auto& al = state.al();

      // iteration in L-BFGS loop collapsed to use just one buffer
      auto q = flat_grad.neg();
      for (int64_t i = num_old - 1; i > -1; i--) {
        (*al).at(i) = old_stps.at(i).dot(q) * ro.at(i);
        q.add_(old_dirs.at(i), -val((*al).at(i)));
      }

      // multiply by initial Hessian
      // r/d is the final direction
      auto r = torch::mul(q, H_diag);
      d = r;
      for (const auto i : c10::irange(num_old)) {
        auto be_i = old_dirs.at(i).dot(r) * ro.at(i);
        r.add_(old_stps.at(i), val((*al).at(i) - be_i));
      }
    }
    if (!prev_flat_grad.defined()) {
      prev_flat_grad = flat_grad.clone(at::MemoryFormat::Contiguous);
    } else {
      prev_flat_grad.copy_(flat_grad);
    }
    prev_loss = loss;

    // ############################################################
    // # compute step length
    // ############################################################
    // reset initial guess for step size
    if (state.n_iter() == 1) {
      t = std::min(1., 1. / val(flat_grad.abs().sum())) * lr;
    } else {
      t = lr;
    }

    // directional derivative
    auto gtd = flat_grad.dot(d); // g * d

    // directional derivative is below tolerance
    if (val(gtd) > -tolerance_change)
      break;

    // optional line search: user function
    auto ls_func_evals = 0;
    if (line_search_fn != std::nullopt) {
      TORCH_CHECK(
          *line_search_fn == "strong_wolfe",
          "only 'strong_wolfe' is supported");
      auto x_init = _clone_param();
      auto obj_func =
          [&](const std::vector<Tensor>& x, double t, const Tensor& d) {
            return _directional_evaluate(closure, x, t, d);
          };
      std::tie(loss, flat_grad, t, ls_func_evals) =
          _strong_wolfe(obj_func, x_init, t, d, loss, flat_grad, gtd);
      _add_grad(t, d);
      opt_cond = (val(flat_grad.abs().max()) <= tolerance_grad);
    } else {
      // no line search, simply move with fixed-step
      _add_grad(t, d);
      if (n_iter != max_iter) {
        // re-evaluate function only if not in last iteration
        // the reason we do this: in a stochastic setting,
        // no use to re-evaluate that function here
        {
          torch::AutoGradMode enable_grad(true);
          loss = val(closure());
        }
        flat_grad = _gather_flat_grad();
        opt_cond = val(torch::max(flat_grad.abs())) <= tolerance_grad;
        ls_func_evals = 1;
      }
    }
    // update func eval
    current_evals += ls_func_evals;
    state.func_evals(state.func_evals() + ls_func_evals);

    // ############################################################
    // # check conditions
    // ############################################################
    if (n_iter == max_iter)
      break;

    if (current_evals >= *max_eval)
      break;

    // optimal condition
    if (opt_cond)
      break;

    // lack of progress
    if (val(d.mul(t).abs().max()) <= tolerance_change)
      break;

    if (std::abs(loss - prev_loss) < tolerance_change)
      break;
  }

  return orig_loss;
}

void LBFGS::save(serialize::OutputArchive& archive) const {
  serialize(*this, archive);
}

void LBFGS::load(serialize::InputArchive& archive) {
  IValue pytorch_version;
  if (archive.try_read("pytorch_version", pytorch_version)) {
    serialize(*this, archive);
  } else {
    // deserializing archives saved in old format (prior to version 1.5.0)
    TORCH_WARN(
        "Your serialized LBFGS optimizer is still using the old serialization format. "
        "The func_evals and n_iter value in state will be set to 0, ro will be set to an empty deque "
        "and al will be set to std::nullopt because the old LBFGS optimizer didn't save these values. "
        "You should re-save your LBFGS optimizer to use the new serialization format.");
    Tensor d, t, H_diag, prev_flat_grad, prev_loss;
    std::deque<Tensor> old_dirs, old_stps;
    archive("d", d, /*is_buffer=*/true);
    archive("t", t, /*is_buffer=*/true);
    archive("H_diag", H_diag, /*is_buffer=*/true);
    archive("prev_flat_grad", prev_flat_grad, /*is_buffer=*/true);
    archive("prev_loss", prev_loss, /*is_buffer=*/true);
    torch::optim::serialize(archive, "old_dirs", old_dirs);
    torch::optim::serialize(archive, "old_stps", old_stps);
    // NOTE: LBFGS has only global state, but we register it as state for
    // the first param, because this helps with casting in load_state_dict
    auto state = std::make_unique<LBFGSParamState>();
    state->d(d);
    state->t(t.item<double>());
    state->H_diag(H_diag);
    state->prev_flat_grad(prev_flat_grad);
    state->prev_loss(prev_loss.item<double>());
    state->old_dirs(old_dirs);
    state->old_stps(old_stps);
    state_[param_groups_.at(0).params().at(0).unsafeGetTensorImpl()] =
        std::move(state);
  }
}
} // namespace optim
} // namespace torch
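
// ---------------------------------------------------------------------------
// Illustrative usage sketch (kept in a comment so it does not become part of
// this translation unit): one way a caller might drive LBFGS::step() with a
// loss closure. The `model`, `input`, and `target` names below are
// hypothetical placeholders, not part of this file.
//
//   #include <torch/torch.h>
//
//   void lbfgs_usage_example() {
//     auto model = torch::nn::Linear(10, 1);
//     auto input = torch::randn({32, 10});
//     auto target = torch::randn({32, 1});
//
//     torch::optim::LBFGS optimizer(
//         model->parameters(),
//         torch::optim::LBFGSOptions(1.0).max_iter(20).line_search_fn(
//             "strong_wolfe"));
//
//     // step() may call the closure several times per invocation (once per
//     // function evaluation of the line search / inner loop).
//     auto loss = optimizer.step([&]() -> torch::Tensor {
//       optimizer.zero_grad();
//       auto out = model->forward(input);
//       auto l = torch::mse_loss(out, target);
//       l.backward();
//       return l;
//     });
//   }
// ---------------------------------------------------------------------------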