• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Ceres Solver - A fast non-linear least squares minimizer
2 // Copyright 2012 Google Inc. All rights reserved.
3 // http://code.google.com/p/ceres-solver/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are met:
7 //
8 // * Redistributions of source code must retain the above copyright notice,
9 //   this list of conditions and the following disclaimer.
10 // * Redistributions in binary form must reproduce the above copyright notice,
11 //   this list of conditions and the following disclaimer in the documentation
12 //   and/or other materials provided with the distribution.
13 // * Neither the name of Google Inc. nor the names of its contributors may be
14 //   used to endorse or promote products derived from this software without
15 //   specific prior written permission.
16 //
17 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18 // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20 // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21 // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22 // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23 // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24 // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25 // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26 // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27 // POSSIBILITY OF SUCH DAMAGE.
28 //
29 // Author: sameeragarwal@google.com (Sameer Agarwal)
30 //
31 // Generic loop for line search based optimization algorithms.
32 //
33 // This is primarily inspired by the minFunc package written by Mark
34 // Schmidt.
35 //
36 // http://www.di.ens.fr/~mschmidt/Software/minFunc.html
37 //
38 // For details on the theory and implementation see "Numerical
39 // Optimization" by Nocedal & Wright.
40 
41 #include "ceres/line_search_minimizer.h"
42 
43 #include <algorithm>
44 #include <cstdlib>
45 #include <cmath>
46 #include <string>
47 #include <vector>
48 
49 #include "Eigen/Dense"
50 #include "ceres/array_utils.h"
51 #include "ceres/evaluator.h"
52 #include "ceres/internal/eigen.h"
53 #include "ceres/internal/port.h"
54 #include "ceres/internal/scoped_ptr.h"
55 #include "ceres/line_search.h"
56 #include "ceres/line_search_direction.h"
57 #include "ceres/stringprintf.h"
58 #include "ceres/types.h"
59 #include "ceres/wall_time.h"
60 #include "glog/logging.h"
61 
62 namespace ceres {
63 namespace internal {
64 namespace {
65 
66 // TODO(sameeragarwal): I think there is a small bug here, in that if
67 // the evaluation fails, then the state can contain garbage. Look at
68 // this more carefully.
Evaluate(Evaluator * evaluator,const Vector & x,LineSearchMinimizer::State * state,string * message)69 bool Evaluate(Evaluator* evaluator,
70               const Vector& x,
71               LineSearchMinimizer::State* state,
72               string* message) {
73   if (!evaluator->Evaluate(x.data(),
74                            &(state->cost),
75                            NULL,
76                            state->gradient.data(),
77                            NULL)) {
78     *message = "Gradient evaluation failed.";
79     return false;
80   }
81 
82   Vector negative_gradient = -state->gradient;
83   Vector projected_gradient_step(x.size());
84   if (!evaluator->Plus(x.data(),
85                        negative_gradient.data(),
86                        projected_gradient_step.data())) {
87     *message = "projected_gradient_step = Plus(x, -gradient) failed.";
88     return false;
89   }
90 
91   state->gradient_squared_norm = (x - projected_gradient_step).squaredNorm();
92   state->gradient_max_norm =
93       (x - projected_gradient_step).lpNorm<Eigen::Infinity>();
94   return true;
95 }
96 
97 }  // namespace
98 
Minimize(const Minimizer::Options & options,double * parameters,Solver::Summary * summary)99 void LineSearchMinimizer::Minimize(const Minimizer::Options& options,
100                                    double* parameters,
101                                    Solver::Summary* summary) {
102   const bool is_not_silent = !options.is_silent;
103   double start_time = WallTimeInSeconds();
104   double iteration_start_time =  start_time;
105 
106   Evaluator* evaluator = CHECK_NOTNULL(options.evaluator);
107   const int num_parameters = evaluator->NumParameters();
108   const int num_effective_parameters = evaluator->NumEffectiveParameters();
109 
110   summary->termination_type = NO_CONVERGENCE;
111   summary->num_successful_steps = 0;
112   summary->num_unsuccessful_steps = 0;
113 
114   VectorRef x(parameters, num_parameters);
115 
116   State current_state(num_parameters, num_effective_parameters);
117   State previous_state(num_parameters, num_effective_parameters);
118 
119   Vector delta(num_effective_parameters);
120   Vector x_plus_delta(num_parameters);
121 
122   IterationSummary iteration_summary;
123   iteration_summary.iteration = 0;
124   iteration_summary.step_is_valid = false;
125   iteration_summary.step_is_successful = false;
126   iteration_summary.cost_change = 0.0;
127   iteration_summary.gradient_max_norm = 0.0;
128   iteration_summary.gradient_norm = 0.0;
129   iteration_summary.step_norm = 0.0;
130   iteration_summary.linear_solver_iterations = 0;
131   iteration_summary.step_solver_time_in_seconds = 0;
132 
133   // Do initial cost and Jacobian evaluation.
134   if (!Evaluate(evaluator, x, &current_state, &summary->message)) {
135     summary->termination_type = FAILURE;
136     summary->message = "Initial cost and jacobian evaluation failed. "
137         "More details: " + summary->message;
138     LOG_IF(WARNING, is_not_silent) << "Terminating: " << summary->message;
139     return;
140   }
141 
142   summary->initial_cost = current_state.cost + summary->fixed_cost;
143   iteration_summary.cost = current_state.cost + summary->fixed_cost;
144 
145   iteration_summary.gradient_max_norm = current_state.gradient_max_norm;
146   iteration_summary.gradient_norm = sqrt(current_state.gradient_squared_norm);
147 
148   if (iteration_summary.gradient_max_norm <= options.gradient_tolerance) {
149     summary->message = StringPrintf("Gradient tolerance reached. "
150                                     "Gradient max norm: %e <= %e",
151                                     iteration_summary.gradient_max_norm,
152                                     options.gradient_tolerance);
153     summary->termination_type = CONVERGENCE;
154     VLOG_IF(1, is_not_silent) << "Terminating: " << summary->message;
155     return;
156   }
157 
158   iteration_summary.iteration_time_in_seconds =
159       WallTimeInSeconds() - iteration_start_time;
160   iteration_summary.cumulative_time_in_seconds =
161       WallTimeInSeconds() - start_time
162       + summary->preprocessor_time_in_seconds;
163   summary->iterations.push_back(iteration_summary);
164 
165   LineSearchDirection::Options line_search_direction_options;
166   line_search_direction_options.num_parameters = num_effective_parameters;
167   line_search_direction_options.type = options.line_search_direction_type;
168   line_search_direction_options.nonlinear_conjugate_gradient_type =
169       options.nonlinear_conjugate_gradient_type;
170   line_search_direction_options.max_lbfgs_rank = options.max_lbfgs_rank;
171   line_search_direction_options.use_approximate_eigenvalue_bfgs_scaling =
172       options.use_approximate_eigenvalue_bfgs_scaling;
173   scoped_ptr<LineSearchDirection> line_search_direction(
174       LineSearchDirection::Create(line_search_direction_options));
175 
176   LineSearchFunction line_search_function(evaluator);
177 
178   LineSearch::Options line_search_options;
179   line_search_options.interpolation_type =
180       options.line_search_interpolation_type;
181   line_search_options.min_step_size = options.min_line_search_step_size;
182   line_search_options.sufficient_decrease =
183       options.line_search_sufficient_function_decrease;
184   line_search_options.max_step_contraction =
185       options.max_line_search_step_contraction;
186   line_search_options.min_step_contraction =
187       options.min_line_search_step_contraction;
188   line_search_options.max_num_iterations =
189       options.max_num_line_search_step_size_iterations;
190   line_search_options.sufficient_curvature_decrease =
191       options.line_search_sufficient_curvature_decrease;
192   line_search_options.max_step_expansion =
193       options.max_line_search_step_expansion;
194   line_search_options.function = &line_search_function;
195 
196   scoped_ptr<LineSearch>
197       line_search(LineSearch::Create(options.line_search_type,
198                                      line_search_options,
199                                      &summary->message));
200   if (line_search.get() == NULL) {
201     summary->termination_type = FAILURE;
202     LOG_IF(ERROR, is_not_silent) << "Terminating: " << summary->message;
203     return;
204   }
205 
206   LineSearch::Summary line_search_summary;
207   int num_line_search_direction_restarts = 0;
208 
209   while (true) {
210     if (!RunCallbacks(options, iteration_summary, summary)) {
211       break;
212     }
213 
214     iteration_start_time = WallTimeInSeconds();
215     if (iteration_summary.iteration >= options.max_num_iterations) {
216       summary->message = "Maximum number of iterations reached.";
217       summary->termination_type = NO_CONVERGENCE;
218       VLOG_IF(1, is_not_silent) << "Terminating: " << summary->message;
219       break;
220     }
221 
222     const double total_solver_time = iteration_start_time - start_time +
223         summary->preprocessor_time_in_seconds;
224     if (total_solver_time >= options.max_solver_time_in_seconds) {
225       summary->message = "Maximum solver time reached.";
226       summary->termination_type = NO_CONVERGENCE;
227       VLOG_IF(1, is_not_silent) << "Terminating: " << summary->message;
228       break;
229     }
230 
231     iteration_summary = IterationSummary();
232     iteration_summary.iteration = summary->iterations.back().iteration + 1;
233     iteration_summary.step_is_valid = false;
234     iteration_summary.step_is_successful = false;
235 
236     bool line_search_status = true;
237     if (iteration_summary.iteration == 1) {
238       current_state.search_direction = -current_state.gradient;
239     } else {
240       line_search_status = line_search_direction->NextDirection(
241           previous_state,
242           current_state,
243           &current_state.search_direction);
244     }
245 
246     if (!line_search_status &&
247         num_line_search_direction_restarts >=
248         options.max_num_line_search_direction_restarts) {
249       // Line search direction failed to generate a new direction, and we
250       // have already reached our specified maximum number of restarts,
251       // terminate optimization.
252       summary->message =
253           StringPrintf("Line search direction failure: specified "
254                        "max_num_line_search_direction_restarts: %d reached.",
255                        options.max_num_line_search_direction_restarts);
256       summary->termination_type = FAILURE;
257       LOG_IF(WARNING, is_not_silent) << "Terminating: " << summary->message;
258       break;
259     } else if (!line_search_status) {
260       // Restart line search direction with gradient descent on first iteration
261       // as we have not yet reached our maximum number of restarts.
262       CHECK_LT(num_line_search_direction_restarts,
263                options.max_num_line_search_direction_restarts);
264 
265       ++num_line_search_direction_restarts;
266       LOG_IF(WARNING, is_not_silent)
267           << "Line search direction algorithm: "
268           << LineSearchDirectionTypeToString(
269               options.line_search_direction_type)
270           << ", failed to produce a valid new direction at "
271           << "iteration: " << iteration_summary.iteration
272           << ". Restarting, number of restarts: "
273           << num_line_search_direction_restarts << " / "
274           << options.max_num_line_search_direction_restarts
275           << " [max].";
276       line_search_direction.reset(
277           LineSearchDirection::Create(line_search_direction_options));
278       current_state.search_direction = -current_state.gradient;
279     }
280 
281     line_search_function.Init(x, current_state.search_direction);
282     current_state.directional_derivative =
283         current_state.gradient.dot(current_state.search_direction);
284 
285     // TODO(sameeragarwal): Refactor this into its own object and add
286     // explanations for the various choices.
287     //
288     // Note that we use !line_search_status to ensure that we treat cases when
289     // we restarted the line search direction equivalently to the first
290     // iteration.
291     const double initial_step_size =
292         (iteration_summary.iteration == 1 || !line_search_status)
293         ? min(1.0, 1.0 / current_state.gradient_max_norm)
294         : min(1.0, 2.0 * (current_state.cost - previous_state.cost) /
295               current_state.directional_derivative);
296     // By definition, we should only ever go forwards along the specified search
297     // direction in a line search, most likely cause for this being violated
298     // would be a numerical failure in the line search direction calculation.
299     if (initial_step_size < 0.0) {
300       summary->message =
301           StringPrintf("Numerical failure in line search, initial_step_size is "
302                        "negative: %.5e, directional_derivative: %.5e, "
303                        "(current_cost - previous_cost): %.5e",
304                        initial_step_size, current_state.directional_derivative,
305                        (current_state.cost - previous_state.cost));
306       summary->termination_type = FAILURE;
307       LOG_IF(WARNING, is_not_silent) << "Terminating: " << summary->message;
308       break;
309     }
310 
311     line_search->Search(initial_step_size,
312                         current_state.cost,
313                         current_state.directional_derivative,
314                         &line_search_summary);
315     if (!line_search_summary.success) {
316       summary->message =
317           StringPrintf("Numerical failure in line search, failed to find "
318                        "a valid step size, (did not run out of iterations) "
319                        "using initial_step_size: %.5e, initial_cost: %.5e, "
320                        "initial_gradient: %.5e.",
321                        initial_step_size, current_state.cost,
322                        current_state.directional_derivative);
323       LOG_IF(WARNING, is_not_silent) << "Terminating: " << summary->message;
324       summary->termination_type = FAILURE;
325       break;
326     }
327 
328     current_state.step_size = line_search_summary.optimal_step_size;
329     delta = current_state.step_size * current_state.search_direction;
330 
331     previous_state = current_state;
332     iteration_summary.step_solver_time_in_seconds =
333         WallTimeInSeconds() - iteration_start_time;
334 
335     if (!evaluator->Plus(x.data(), delta.data(), x_plus_delta.data())) {
336       summary->termination_type = FAILURE;
337       summary->message =
338           "x_plus_delta = Plus(x, delta) failed. This should not happen "
339           "as the step was valid when it was selected by the line search.";
340       LOG_IF(WARNING, is_not_silent) << "Terminating: " << summary->message;
341       break;
342     } else if (!Evaluate(evaluator,
343                          x_plus_delta,
344                          &current_state,
345                          &summary->message)) {
346       summary->termination_type = FAILURE;
347       summary->message =
348           "Step failed to evaluate. This should not happen as the step was "
349           "valid when it was selected by the line search. More details: " +
350           summary->message;
351       LOG_IF(WARNING, is_not_silent) << "Terminating: " << summary->message;
352       break;
353     } else {
354       x = x_plus_delta;
355     }
356 
357     iteration_summary.gradient_max_norm = current_state.gradient_max_norm;
358     iteration_summary.gradient_norm = sqrt(current_state.gradient_squared_norm);
359     iteration_summary.cost_change = previous_state.cost - current_state.cost;
360     iteration_summary.cost = current_state.cost + summary->fixed_cost;
361     iteration_summary.step_norm = delta.norm();
362     iteration_summary.step_is_valid = true;
363     iteration_summary.step_is_successful = true;
364     iteration_summary.step_norm = delta.norm();
365     iteration_summary.step_size =  current_state.step_size;
366     iteration_summary.line_search_function_evaluations =
367         line_search_summary.num_function_evaluations;
368     iteration_summary.line_search_gradient_evaluations =
369         line_search_summary.num_gradient_evaluations;
370     iteration_summary.line_search_iterations =
371         line_search_summary.num_iterations;
372     iteration_summary.iteration_time_in_seconds =
373         WallTimeInSeconds() - iteration_start_time;
374     iteration_summary.cumulative_time_in_seconds =
375         WallTimeInSeconds() - start_time
376         + summary->preprocessor_time_in_seconds;
377 
378     summary->iterations.push_back(iteration_summary);
379     ++summary->num_successful_steps;
380 
381     if (iteration_summary.gradient_max_norm <= options.gradient_tolerance) {
382       summary->message = StringPrintf("Gradient tolerance reached. "
383                                       "Gradient max norm: %e <= %e",
384                                       iteration_summary.gradient_max_norm,
385                                       options.gradient_tolerance);
386       summary->termination_type = CONVERGENCE;
387       VLOG_IF(1, is_not_silent) << "Terminating: " << summary->message;
388       break;
389     }
390 
391     const double absolute_function_tolerance =
392         options.function_tolerance * previous_state.cost;
393     if (fabs(iteration_summary.cost_change) < absolute_function_tolerance) {
394       summary->message =
395           StringPrintf("Function tolerance reached. "
396                        "|cost_change|/cost: %e <= %e",
397                        fabs(iteration_summary.cost_change) /
398                        previous_state.cost,
399                        options.function_tolerance);
400       summary->termination_type = CONVERGENCE;
401       VLOG_IF(1, is_not_silent) << "Terminating: " << summary->message;
402       break;
403     }
404   }
405 }
406 
407 }  // namespace internal
408 }  // namespace ceres
409