00001 #ifndef SCITBX_LBFGS_H
00002 #define SCITBX_LBFGS_H
00003
00004 #include <stdio.h>
00005 #include <cstddef>
00006 #include <cmath>
00007 #include <stdexcept>
00008 #include <algorithm>
00009 #include <vector>
00010 #include <string>
00011
00012 namespace scitbx {
00013
00015
00052 namespace lbfgs {
00053
00055
//! Common base class of all exceptions defined in this file.
/*! The text passed to the constructor is stored behind the fixed prefix
    "lbfgs error: " and handed back unchanged by what().
 */
class error : public std::exception {
  public:
    //! Stores "lbfgs error: " followed by msg.
    error(std::string const& msg) throw()
      : msg_(std::string("lbfgs error: ") + msg)
    {}

    //! Returns the stored message; the pointer is owned by this object.
    virtual const char* what() const throw() { return msg_.c_str(); }

#ifndef DOXYGEN_SHOULD_SKIP_THIS
    //! Formats i as a decimal string (helper for building messages).
    static std::string itoa(unsigned long i) {
      char digits[80];
      sprintf(digits, "%lu", i);
      return std::string(digits);
    }
#endif

  protected:
    // The destructor is protected, so objects of exactly this type cannot
    // be created as locals or thrown directly; the derived classes below
    // are the ones that get thrown.
    virtual ~error() throw() {}

    //! Complete message, including the "lbfgs error: " prefix.
    std::string msg_;
};
00077
00079 class error_internal_error : public error {
00080 public:
00082 error_internal_error(const char* file, unsigned long line) throw()
00083 : error(
00084 "Internal Error: " + std::string(file) + "(" + itoa(line) + ")")
00085 {}
00086 };
00087
00089 class error_improper_input_parameter : public error {
00090 public:
00092 error_improper_input_parameter(std::string const& msg) throw()
00093 : error("Improper input parameter: " + msg)
00094 {}
00095 };
00096
00098 class error_improper_input_data : public error {
00099 public:
00101 error_improper_input_data(std::string const& msg) throw()
00102 : error("Improper input data: " + msg)
00103 {}
00104 };
00105
00107 class error_search_direction_not_descent : public error {
00108 public:
00110 error_search_direction_not_descent() throw()
00111 : error("The search direction is not a descent direction.")
00112 {}
00113 };
00114
00116 class error_line_search_failed : public error {
00117 public:
00119 error_line_search_failed(std::string const& msg) throw()
00120 : error("Line search failed: " + msg)
00121 {}
00122 };
00123
00125 class error_line_search_failed_rounding_errors
00126 : public error_line_search_failed {
00127 public:
00129 error_line_search_failed_rounding_errors(std::string const& msg) throw()
00130 : error_line_search_failed(msg)
00131 {}
00132 };
00133
00134 namespace detail {
00135
//! Returns x squared, computed as one multiplication of x by itself.
template <typename NumType>
inline NumType pow2(NumType const& x)
{
  return x * x;
}
00140
//! Returns -x when x compares less than zero, otherwise x itself.
template <typename NumType>
inline NumType abs(NumType const& x)
{
  return (x < NumType(0)) ? -x : x;
}
00148
00149
//! Line search used by the minimizer, driven by reverse communication.
/*! An instance stores the state of one line search between successive
    calls of run(). run() returns with info == -1 whenever it has written
    a new trial point into x; the caller then supplies the function value
    and gradient at that point through another call of run().
 */
template <typename FloatType, typename SizeType = std::size_t>
class mcsrch
{
  protected:
    // ---- state kept from one call of run() to the next ----

    //! Set to 1 when a search starts, afterwards the last result of mcstep().
    int infoc;
    //! Sum of g[j] * s[is0+j] at the start of the search.
    FloatType dginit;
    //! False when a search starts; mcstep() switches it to true.
    bool brackt;
    //! True when a search starts; cleared by run() once its test is met.
    bool stage1;
    //! Function value at the start of the search.
    FloatType finit;
    //! ftol * dginit.
    FloatType dgtest;
    //! Starts as stpmax - stpmin; later |sty - stx| once brackt is true.
    FloatType width;
    //! Starts as 2 * width; later the previous value of width.
    FloatType width1;
    //! Step, function value and derivative at the first interval endpoint.
    /*! mcstep() moves this endpoint to the trial step only when the trial
        function value is not larger than fx.
     */
    FloatType stx;
    FloatType fx;
    FloatType dgx;
    //! Step, function value and derivative at the second interval endpoint.
    FloatType sty;
    FloatType fy;
    FloatType dgy;
    //! Lower and upper limits handed to mcstep() for the next step.
    FloatType stmin;
    FloatType stmax;

    //! Returns a reference to the largest of the three arguments.
    static FloatType const& max3(
      FloatType const& x,
      FloatType const& y,
      FloatType const& z)
    {
      return std::max(std::max(x, y), z);
    }

  public:
    //! Advances the line search by one reverse-communication step.
    /*! Call with info != -1 to begin a new search along the direction
        s[is0 .. is0+n) from the point x with function value f and gradient
        g. The starting point is saved in wa, a trial point is written to
        x, info is set to -1 and the call returns. The caller evaluates f
        and g at the new x and calls run() again with info still -1. This
        repeats until run() returns with info == 1, which means both
        acceptance tests in the definition are satisfied for the step stp.

        @param gtol    tolerance of the derivative test (must be >= 0)
        @param stpmin  smallest admissible step (must be >= 0)
        @param stpmax  largest admissible step (must be >= stpmin)
        @param n       number of variables (must be > 0)
        @param x       in: current point; out: next trial point
        @param f       function value at x
        @param g       gradient at x, n elements
        @param s       array holding the search direction
        @param is0     offset of the search direction inside s
        @param stp     in: initial step (must be > 0); out: current step
        @param ftol    tolerance of the function-decrease test (>= 0)
        @param xtol    relative width below which the search gives up
        @param maxfev  limit on the number of evaluations (must be > 0)
        @param info    in: -1 to continue a search; out: -1 or 1
        @param nfev    number of evaluations counted for this search
        @param wa      work array of n elements holding the starting point

        Throws error_internal_error for invalid arguments,
        error_search_direction_not_descent when g.s >= 0 at the start,
        and error_line_search_failed (or its rounding-error subclass) for
        the abnormal termination conditions tested in the definition.
     */
    void run(
      FloatType const& gtol,
      FloatType const& stpmin,
      FloatType const& stpmax,
      SizeType n,
      FloatType* x,
      FloatType f,
      const FloatType* g,
      FloatType* s,
      SizeType is0,
      FloatType& stp,
      FloatType ftol,
      FloatType xtol,
      SizeType maxfev,
      int& info,
      SizeType& nfev,
      FloatType* wa);

    //! Computes the next trial step and updates the interval endpoints.
    /*! (stx, fx, dx) and (sty, fy, dy) are the step, function value and
        derivative at the two endpoints; (stp, fp, dp) are the same
        quantities at the current trial step. On return the endpoints
        have been updated and stp holds the new trial step, limited to
        [stpmin, stpmax].

        Returns 0 without changing anything when the inputs are
        inconsistent, otherwise a value from 1 to 4 naming the branch
        that produced the step.
     */
    static int mcstep(
      FloatType& stx,
      FloatType& fx,
      FloatType& dx,
      FloatType& sty,
      FloatType& fy,
      FloatType& dy,
      FloatType& stp,
      FloatType fp,
      FloatType dp,
      bool& brackt,
      FloatType stpmin,
      FloatType stpmax);
};
00364
//! Advances the line search by one reverse-communication step.
/*! Entered with info != -1, the function validates its arguments, records
    the starting point (x is copied to wa) and the initial directional
    derivative, writes the first trial point into x, sets info = -1 and
    returns.

    Entered with info == -1, f and g are taken to be the function value and
    gradient at the trial point written by the previous call. The function
    then either
      - returns with info == 1 when both acceptance tests hold,
      - throws one of the error_line_search_failed exceptions, or
      - computes a new step with mcstep(), writes the matching trial point
        into x and returns with info == -1 again.

    Throws error_internal_error for invalid arguments and
    error_search_direction_not_descent when g.s >= 0 at the start.
 */
template <typename FloatType, typename SizeType>
void mcsrch<FloatType, SizeType>::run(
  FloatType const& gtol,
  FloatType const& stpmin,
  FloatType const& stpmax,
  SizeType n,
  FloatType* x,
  FloatType f,
  const FloatType* g,
  FloatType* s,
  SizeType is0,
  FloatType& stp,
  FloatType ftol,
  FloatType xtol,
  SizeType maxfev,
  int& info,
  SizeType& nfev,
  FloatType* wa)
{
  if (info != -1) {
    // Start of a new line search: check the arguments and reset the state.
    infoc = 1;
    if (   n == 0
        || maxfev == 0
        || gtol < FloatType(0)
        || xtol < FloatType(0)
        || stpmin < FloatType(0)
        || stpmax < stpmin) {
      throw error_internal_error(__FILE__, __LINE__);
    }
    if (stp <= FloatType(0) || ftol < FloatType(0)) {
      throw error_internal_error(__FILE__, __LINE__);
    }
    // dginit = g . s, the derivative along the search direction at the
    // starting point. It must be negative for the search to proceed.
    dginit = FloatType(0);
    for (SizeType j = 0; j < n; j++) {
      dginit += g[j] * s[is0+j];
    }
    if (dginit >= FloatType(0)) {
      throw error_search_direction_not_descent();
    }
    brackt = false;
    stage1 = true;
    nfev = 0;
    finit = f;
    dgtest = ftol*dginit;
    width = stpmax - stpmin;
    width1 = FloatType(2) * width;
    // Keep the starting point; every trial point is built from it.
    std::copy(x, x+n, wa);
    // Both interval endpoints start at step 0 with the initial function
    // value and derivative.
    stx = FloatType(0);
    fx = finit;
    dgx = dginit;
    sty = FloatType(0);
    fy = finit;
    dgy = dginit;
  }
  for (;;) {
    if (info != -1) {
      // Produce the next trial point. This branch is skipped on the first
      // pass of a continuation call (info == -1), where f and g for the
      // previous trial point still have to be processed.
      if (brackt) {
        stmin = std::min(stx, sty);
        stmax = std::max(stx, sty);
      }
      else {
        stmin = stx;
        stmax = stp + FloatType(4) * (stp - stx);
      }
      // Keep the step inside [stpmin, stpmax].
      stp = std::max(stp, stpmin);
      stp = std::min(stp, stpmax);
      // Under any of these conditions the step falls back to stx. The same
      // conditions are tested again below, after the next evaluation,
      // where they raise the corresponding exception.
      if (   (brackt && (stp <= stmin || stp >= stmax))
          || nfev >= maxfev - 1 || infoc == 0
          || (brackt && stmax - stmin <= xtol * stmax)) {
        stp = stx;
      }
      // Trial point = starting point + stp * search direction.
      for (SizeType j = 0; j < n; j++) {
        x[j] = wa[j] + stp * s[is0+j];
      }
      // Hand control back to the caller to evaluate f and g at x.
      info=-1;
      break;
    }
    // From here on f and g belong to the trial point for the current stp.
    info = 0;
    nfev++;
    // dg = g . s, the derivative along the search direction at the trial
    // point.
    FloatType dg(0);
    for (SizeType j = 0; j < n; j++) {
      dg += g[j] * s[is0+j];
    }
    // Threshold of the function-decrease test for this step.
    FloatType ftest1 = finit + stp*dgtest;
    // Abnormal termination conditions.
    if ((brackt && (stp <= stmin || stp >= stmax)) || infoc == 0) {
      throw error_line_search_failed_rounding_errors(
        "Rounding errors prevent further progress."
        " There may not be a step which satisfies the"
        " sufficient decrease and curvature conditions."
        " Tolerances may be too small.");
    }
    if (stp == stpmax && f <= ftest1 && dg <= dgtest) {
      throw error_line_search_failed(
        "The step is at the upper bound stpmax().");
    }
    if (stp == stpmin && (f > ftest1 || dg >= dgtest)) {
      throw error_line_search_failed(
        "The step is at the lower bound stpmin().");
    }
    if (nfev >= maxfev) {
      throw error_line_search_failed(
        "Number of function evaluations has reached maxfev().");
    }
    if (brackt && stmax - stmin <= xtol * stmax) {
      throw error_line_search_failed(
        "Relative width of the interval of uncertainty"
        " is at most xtol().");
    }
    // Success: the function value passes the decrease test and the
    // derivative magnitude is at most gtol * |dginit|.
    if (f <= ftest1 && abs(dg) <= gtol * (-dginit)) {
      info = 1;
      break;
    }
    // Leave the first stage once the decrease test holds and the
    // derivative is at least min(ftol, gtol) * dginit.
    if (   stage1 && f <= ftest1
        && dg >= std::min(ftol, gtol) * dginit) {
      stage1 = false;
    }
    if (stage1 && f <= fx && f > ftest1) {
      // First stage with a lower function value that still fails the
      // decrease test: run mcstep() on values shifted by the linear term
      // step * dgtest (function) and dgtest (derivative).
      FloatType fm = f - stp*dgtest;
      FloatType fxm = fx - stx*dgtest;
      FloatType fym = fy - sty*dgtest;
      FloatType dgm = dg - dgtest;
      FloatType dgxm = dgx - dgtest;
      FloatType dgym = dgy - dgtest;
      infoc = mcstep(stx, fxm, dgxm, sty, fym, dgym, stp, fm, dgm,
                     brackt, stmin, stmax);
      // Undo the shift on the endpoint values updated by mcstep().
      fx = fxm + stx*dgtest;
      fy = fym + sty*dgtest;
      dgx = dgxm + dgtest;
      dgy = dgym + dgtest;
    }
    else {
      // Otherwise run mcstep() on the unmodified values.
      infoc = mcstep(stx, fx, dgx, sty, fy, dgy, stp, f, dg,
                     brackt, stmin, stmax);
    }
    // Once brackt is set, take the midpoint of the interval whenever the
    // interval is still at least 0.66 times width1.
    if (brackt) {
      if (abs(sty - stx) >= FloatType(0.66) * width1) {
        stp = stx + FloatType(0.5) * (sty - stx);
      }
      width1 = width;
      width = abs(sty - stx);
    }
  }
}
00542
//! Computes the next trial step and updates the interval endpoints.
/*! (stx, fx, dx) and (sty, fy, dy) hold the step, function value and
    derivative at the two interval endpoints; (stp, fp, dp) are the same
    quantities at the current trial step.

    Returns 0 without modifying any argument when the inputs are
    inconsistent: stp outside the open interval between the endpoints while
    brackt is set, dx * (stp - stx) >= 0, or stpmax < stpmin. Otherwise the
    return value 1..4 names the branch taken below; the endpoints are
    updated, stp receives the new trial step limited to [stpmin, stpmax],
    and brackt is set to true by branches 1 and 2.

    NOTE(review): the names stpc / stpq suggest the cubic and quadratic
    interpolation steps of the MINPACK-2 routine of the same name; confirm
    against that reference before relying on it.
 */
template <typename FloatType, typename SizeType>
int mcsrch<FloatType, SizeType>::mcstep(
  FloatType& stx,
  FloatType& fx,
  FloatType& dx,
  FloatType& sty,
  FloatType& fy,
  FloatType& dy,
  FloatType& stp,
  FloatType fp,
  FloatType dp,
  bool& brackt,
  FloatType stpmin,
  FloatType stpmax)
{
  bool bound;
  FloatType gamma, p, q, r, s, sgnd, stpc, stpf, stpq, theta;
  int info = 0;
  // Reject inconsistent input. Note that dx == 0 is rejected here as well,
  // which keeps the division by abs(dx) below well defined.
  if (   (   brackt && (stp <= std::min(stx, sty)
          || stp >= std::max(stx, sty)))
      || dx * (stp - stx) >= FloatType(0) || stpmax < stpmin) {
    return 0;
  }
  // sgnd is dp multiplied by the sign of dx: negative exactly when the
  // derivatives at stx and at stp have opposite signs.
  sgnd = dp * (dx / abs(dx));
  if (fp > fx) {
    // Branch 1: the function value at stp is higher than at stx.
    // Two candidate steps measured from stx are computed; the one closer
    // to stx is taken if it is stpc, otherwise the average of the two.
    info = 1;
    bound = true;
    theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp;
    // Dividing by s (the largest magnitude) before squaring keeps the
    // intermediate values of the square-root argument scaled.
    s = max3(abs(theta), abs(dx), abs(dp));
    gamma = s * std::sqrt(pow2(theta / s) - (dx / s) * (dp / s));
    if (stp < stx) gamma = - gamma;
    p = (gamma - dx) + theta;
    q = ((gamma - dx) + gamma) + dp;
    r = p/q;
    stpc = stx + r * (stp - stx);
    stpq = stx
      + ((dx / ((fx - fp) / (stp - stx) + dx)) / FloatType(2))
        * (stp - stx);
    if (abs(stpc - stx) < abs(stpq - stx)) {
      stpf = stpc;
    }
    else {
      stpf = stpc + (stpq - stpc) / FloatType(2);
    }
    brackt = true;
  }
  else if (sgnd < FloatType(0)) {
    // Branch 2: function value not higher, derivatives of opposite sign.
    // Of the two candidate steps the one farther from stp is taken.
    info = 2;
    bound = false;
    theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp;
    s = max3(abs(theta), abs(dx), abs(dp));
    gamma = s * std::sqrt(pow2(theta / s) - (dx / s) * (dp / s));
    if (stp > stx) gamma = - gamma;
    p = (gamma - dp) + theta;
    q = ((gamma - dp) + gamma) + dx;
    r = p/q;
    stpc = stp + r * (stx - stp);
    stpq = stp + (dp / (dp - dx)) * (stx - stp);
    if (abs(stpc - stp) > abs(stpq - stp)) {
      stpf = stpc;
    }
    else {
      stpf = stpq;
    }
    brackt = true;
  }
  else if (abs(dp) < abs(dx)) {
    // Branch 3: function value not higher, derivatives of the same sign,
    // and the derivative magnitude decreases.
    // stpc is only computed from r when r < 0 and gamma != 0; otherwise it
    // is set to stpmax or stpmin depending on the side of stx that stp is
    // on. With brackt set, the candidate closer to stp is taken; without
    // it, the one farther away.
    info = 3;
    bound = true;
    theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp;
    s = max3(abs(theta), abs(dx), abs(dp));
    // Unlike the other branches, the square-root argument is clamped at
    // zero here.
    gamma = s * std::sqrt(
      std::max(FloatType(0), pow2(theta / s) - (dx / s) * (dp / s)));
    if (stp > stx) gamma = -gamma;
    p = (gamma - dp) + theta;
    q = (gamma + (dx - dp)) + gamma;
    r = p/q;
    if (r < FloatType(0) && gamma != FloatType(0)) {
      stpc = stp + r * (stx - stp);
    }
    else if (stp > stx) {
      stpc = stpmax;
    }
    else {
      stpc = stpmin;
    }
    stpq = stp + (dp / (dp - dx)) * (stx - stp);
    if (brackt) {
      if (abs(stp - stpc) < abs(stp - stpq)) {
        stpf = stpc;
      }
      else {
        stpf = stpq;
      }
    }
    else {
      if (abs(stp - stpc) > abs(stp - stpq)) {
        stpf = stpc;
      }
      else {
        stpf = stpq;
      }
    }
  }
  else {
    // Branch 4: function value not higher, derivatives of the same sign,
    // and the derivative magnitude does not decrease.
    // With brackt set the step is computed from stp and the sty endpoint;
    // otherwise it is stpmax or stpmin.
    info = 4;
    bound = false;
    if (brackt) {
      theta = FloatType(3) * (fp - fy) / (sty - stp) + dy + dp;
      s = max3(abs(theta), abs(dy), abs(dp));
      gamma = s * std::sqrt(pow2(theta / s) - (dy / s) * (dp / s));
      if (stp > sty) gamma = -gamma;
      p = (gamma - dp) + theta;
      q = ((gamma - dp) + gamma) + dy;
      r = p/q;
      stpc = stp + r * (sty - stp);
      stpf = stpc;
    }
    else if (stp > stx) {
      stpf = stpmax;
    }
    else {
      stpf = stpmin;
    }
  }
  // Update the interval endpoints. This does not depend on the step just
  // computed: a higher function value replaces the sty endpoint, anything
  // else replaces stx (moving the old stx to sty first when the
  // derivatives have opposite signs).
  if (fp > fx) {
    sty = stp;
    fy = fp;
    dy = dp;
  }
  else {
    if (sgnd < FloatType(0)) {
      sty = stx;
      fy = fx;
      dy = dx;
    }
    stx = stp;
    fx = fp;
    dx = dp;
  }
  // Limit the new step to [stpmin, stpmax].
  stpf = std::min(stpmax, stpf);
  stpf = std::max(stpmin, stpf);
  stp = stpf;
  // In branches 1 and 3 (bound == true) with brackt set, additionally keep
  // the step within 0.66 of the way from stx towards sty.
  if (brackt && bound) {
    if (sty > stx) {
      stp = std::min(stx + FloatType(0.66) * (sty - stx), stp);
    }
    else {
      stp = std::max(stx + FloatType(0.66) * (sty - stx), stp);
    }
  }
  return info;
}
00720
00721
00722
00723
00724
//! Adds da times a strided vector to another strided vector, in place.
/*! For k = 0 .. n-1:
      dy[iy0 + k*incy] += da * dx[ix0 + k*incx]
    Elements are processed in ascending k. Nothing is done when n is zero
    or da is zero.
 */
template <typename FloatType, typename SizeType>
void daxpy(
  SizeType n,
  FloatType da,
  const FloatType* dx,
  SizeType ix0,
  SizeType incx,
  FloatType* dy,
  SizeType iy0,
  SizeType incy)
{
  if (n == 0 || da == FloatType(0)) return;
  SizeType src = ix0;
  SizeType dst = iy0;
  for (SizeType k = 0; k < n; k++) {
    dy[dst] += da * dx[src];
    src += incx;
    dst += incy;
  }
}
00760
//! Contiguous form of daxpy: dy[k] += da * dx[ix0 + k] for k = 0 .. n-1.
/*! Forwards to the general overload with both increments equal to one and
    no offset into dy.
 */
template <typename FloatType, typename SizeType>
inline
void daxpy(
  SizeType n,
  FloatType da,
  const FloatType* dx,
  SizeType ix0,
  FloatType* dy)
{
  const SizeType unit_stride(1);
  const SizeType no_offset(0);
  daxpy(n, da, dx, ix0, unit_stride, dy, no_offset, unit_stride);
}
00772
00773
00774
00775
00776
//! Dot product of two strided vectors.
/*! Returns the sum over k = 0 .. n-1 of
      dx[ix0 + k*incx] * dy[iy0 + k*incy],
    accumulated in ascending k starting from FloatType(0). The result is
    FloatType(0) when n is zero.
 */
template <typename FloatType, typename SizeType>
FloatType ddot(
  SizeType n,
  const FloatType* dx,
  SizeType ix0,
  SizeType incx,
  const FloatType* dy,
  SizeType iy0,
  SizeType incy)
{
  FloatType total(0);
  SizeType ix = ix0;
  SizeType iy = iy0;
  for (SizeType k = 0; k < n; k++) {
    total += dx[ix] * dy[iy];
    ix += incx;
    iy += incy;
  }
  return total;
}
00813
//! Contiguous form of ddot: sum of dx[k] * dy[k] for k = 0 .. n-1.
/*! Forwards to the general overload with zero offsets and unit
    increments.
 */
template <typename FloatType, typename SizeType>
inline
FloatType ddot(
  SizeType n,
  const FloatType* dx,
  const FloatType* dy)
{
  const SizeType no_offset(0);
  const SizeType unit_stride(1);
  return ddot(n, dx, no_offset, unit_stride, dy, no_offset, unit_stride);
}
00824
00825 }
00826
00828
00879 template <typename FloatType, typename SizeType = std::size_t>
00880 class minimizer
00881 {
00882 public:
00884 minimizer()
00885 : n_(0), m_(0), maxfev_(0),
00886 gtol_(0), xtol_(0),
00887 stpmin_(0), stpmax_(0),
00888 ispt(0), iypt(0)
00889 {}
00890
00892
00935 explicit
00936 minimizer(
00937 SizeType n,
00938 SizeType m = 5,
00939 SizeType maxfev = 20,
00940 FloatType gtol = FloatType(0.9),
00941 FloatType xtol = FloatType(1.e-16),
00942 FloatType stpmin = FloatType(1.e-20),
00943 FloatType stpmax = FloatType(1.e20))
00944 : n_(n), m_(m), maxfev_(maxfev),
00945 gtol_(gtol), xtol_(xtol),
00946 stpmin_(stpmin), stpmax_(stpmax),
00947 iflag_(0), requests_f_and_g_(false), requests_diag_(false),
00948 iter_(0), nfun_(0), stp_(0),
00949 stp1(0), ftol(0.0001), ys(0), point(0), npt(0),
00950 ispt(n+2*m), iypt((n+2*m)+n*m),
00951 info(0), bound(0), nfev(0)
00952 {
00953 if (n_ == 0) {
00954 throw error_improper_input_parameter("n = 0.");
00955 }
00956 if (m_ == 0) {
00957 throw error_improper_input_parameter("m = 0.");
00958 }
00959 if (maxfev_ == 0) {
00960 throw error_improper_input_parameter("maxfev = 0.");
00961 }
00962 if (gtol_ <= FloatType(1.e-4)) {
00963 throw error_improper_input_parameter("gtol <= 1.e-4.");
00964 }
00965 if (xtol_ < FloatType(0)) {
00966 throw error_improper_input_parameter("xtol < 0.");
00967 }
00968 if (stpmin_ < FloatType(0)) {
00969 throw error_improper_input_parameter("stpmin < 0.");
00970 }
00971 if (stpmax_ < stpmin) {
00972 throw error_improper_input_parameter("stpmax < stpmin");
00973 }
00974 w_.resize(n_*(2*m_+1)+2*m_);
00975 scratch_array_.resize(n_);
00976 }
00977
00979 SizeType n() const { return n_; }
00980
00982 SizeType m() const { return m_; }
00983
00987 SizeType maxfev() const { return maxfev_; }
00988
00992 FloatType gtol() const { return gtol_; }
00993
00995 FloatType xtol() const { return xtol_; }
00996
01000 FloatType stpmin() const { return stpmin_; }
01001
01005 FloatType stpmax() const { return stpmax_; }
01006
01008
01017 bool requests_f_and_g() const { return requests_f_and_g_; }
01018
01020
01028 bool requests_diag() const { return requests_diag_; }
01029
01031
01036 SizeType iter() const { return iter_; }
01037
01039
01045 SizeType nfun() const { return nfun_; }
01046
01048 FloatType euclidean_norm(const FloatType* a) const {
01049 return std::sqrt(detail::ddot(n_, a, a));
01050 }
01051
01053 FloatType stp() const { return stp_; }
01054
01056
01084 bool run(
01085 FloatType* x,
01086 FloatType f,
01087 const FloatType* g)
01088 {
01089 return generic_run(x, f, g, false, 0);
01090 }
01091
01093
01108 bool run(
01109 FloatType* x,
01110 FloatType f,
01111 const FloatType* g,
01112 const FloatType* diag)
01113 {
01114 return generic_run(x, f, g, true, diag);
01115 }
01116
01117 protected:
01118 static void throw_diagonal_element_not_positive(SizeType i) {
01119 throw error_improper_input_data(
01120 "The " + error::itoa(i) + ". diagonal element of the"
01121 " inverse Hessian approximation is not positive.");
01122 }
01123
01124 bool generic_run(
01125 FloatType* x,
01126 FloatType f,
01127 const FloatType* g,
01128 bool diagco,
01129 const FloatType* diag);
01130
01131 detail::mcsrch<FloatType, SizeType> mcsrch_instance;
01132 const SizeType n_;
01133 const SizeType m_;
01134 const SizeType maxfev_;
01135 const FloatType gtol_;
01136 const FloatType xtol_;
01137 const FloatType stpmin_;
01138 const FloatType stpmax_;
01139 int iflag_;
01140 bool requests_f_and_g_;
01141 bool requests_diag_;
01142 SizeType iter_;
01143 SizeType nfun_;
01144 FloatType stp_;
01145 FloatType stp1;
01146 FloatType ftol;
01147 FloatType ys;
01148 SizeType point;
01149 SizeType npt;
01150 const SizeType ispt;
01151 const SizeType iypt;
01152 int info;
01153 SizeType bound;
01154 SizeType nfev;
01155 std::vector<FloatType> w_;
01156 std::vector<FloatType> scratch_array_;
01157 };
01158
//! Shared implementation of both run() overloads.
/*! diagco tells whether the caller supplies the diagonal diag; when it is
    false, diag is replaced by values computed here.

    The function is re-entered after every request for more information.
    Where execution resumes depends on the flags left by the previous call:
      - neither request flag set: a new iteration is started;
      - iflag_ == 2 (diagonal was requested): the search direction is
        built using the supplied diag;
      - iflag_ == 1 (f and g were requested): execution goes straight to
        the line search.

    Layout of the workspace w, as used below:
      w[0 .. n)                  working vector (search direction while it
                                 is built, then a copy of the gradient)
      w[n .. n+m)                1/ys for each stored pair
      w[n+m .. n+2m)             per-pair coefficients of the first loop
      w[ispt + k*n .. +n)        stored step vector of slot k
      w[iypt + k*n .. +n)        stored gradient difference of slot k

    Returns true if the caller must provide more information, false when
    the line search has finished or the initial gradient norm is zero.
 */
template <typename FloatType, typename SizeType>
bool minimizer<FloatType, SizeType>::generic_run(
  FloatType* x,
  FloatType f,
  const FloatType* g,
  bool diagco,
  const FloatType* diag)
{
  // A call that does not answer a pending request starts a new iteration.
  bool execute_entire_while_loop = false;
  if (!(requests_f_and_g_ || requests_diag_)) {
    execute_entire_while_loop = true;
  }
  requests_f_and_g_ = false;
  requests_diag_ = false;
  FloatType* w = &(*(w_.begin()));
  if (iflag_ == 0) {
    // Very first call: set up the initial search direction -diag * g.
    nfun_ = 1;
    if (diagco) {
      for (SizeType i = 0; i < n_; i++) {
        if (diag[i] <= FloatType(0)) {
          throw_diagonal_element_not_positive(i);
        }
      }
    }
    else {
      // No diagonal supplied: use all ones.
      std::fill_n(scratch_array_.begin(), n_, FloatType(1));
      diag = &(*(scratch_array_.begin()));
    }
    for (SizeType i = 0; i < n_; i++) {
      w[ispt + i] = -g[i] * diag[i];
    }
    FloatType gnorm = std::sqrt(detail::ddot(n_, g, g));
    // A zero gradient leaves nothing to do.
    if (gnorm == FloatType(0)) return false;
    stp1 = FloatType(1) / gnorm;
    execute_entire_while_loop = true;
  }
  if (execute_entire_while_loop) {
    // Start of an iteration.
    bound = iter_;
    iter_++;
    info = 0;
    if (iter_ != 1) {
      if (iter_ > m_) bound = m_;
      // ys = y . s for the most recently stored pair.
      ys = detail::ddot(
        n_, w, iypt + npt, SizeType(1), w, ispt + npt, SizeType(1));
      if (!diagco) {
        // Default diagonal: the constant ys / yy in every element.
        FloatType yy = detail::ddot(
          n_, w, iypt + npt, SizeType(1), w, iypt + npt, SizeType(1));
        std::fill_n(scratch_array_.begin(), n_, ys / yy);
        diag = &(*(scratch_array_.begin()));
      }
      else {
        // Ask the caller for the diagonal; execution resumes in the next
        // block on re-entry.
        iflag_ = 2;
        requests_diag_ = true;
        return true;
      }
    }
  }
  if (execute_entire_while_loop || iflag_ == 2) {
    if (iter_ != 1) {
      if (diag == 0) {
        throw error_internal_error(__FILE__, __LINE__);
      }
      if (diagco) {
        for (SizeType i = 0; i < n_; i++) {
          if (diag[i] <= FloatType(0)) {
            throw_diagonal_element_not_positive(i);
          }
        }
      }
      // Store 1/ys for the most recent pair, which sits in the slot
      // before `point` (wrapping around to m_-1).
      SizeType cp = point;
      if (point == 0) cp = m_;
      w[n_ + cp -1] = 1 / ys;
      SizeType i;
      // Build the search direction in w[0 .. n), starting from -g.
      for (i = 0; i < n_; i++) {
        w[i] = -g[i];
      }
      // First loop: walk backwards through the `bound` most recent pairs,
      // storing a coefficient per pair in w[n+m+cp] and subtracting that
      // multiple of the stored y vector.
      cp = point;
      for (i = 0; i < bound; i++) {
        if (cp == 0) cp = m_;
        cp--;
        FloatType sq = detail::ddot(
          n_, w, ispt + cp * n_, SizeType(1), w, SizeType(0), SizeType(1));
        SizeType inmc=n_+m_+cp;
        SizeType iycn=iypt+cp*n_;
        w[inmc] = w[n_ + cp] * sq;
        detail::daxpy(n_, -w[inmc], w, iycn, w);
      }
      // Scale by the diagonal.
      for (i = 0; i < n_; i++) {
        w[i] *= diag[i];
      }
      // Second loop: walk forwards over the same pairs (cp continues from
      // where the first loop stopped), adding a multiple of each stored s
      // vector.
      for (i = 0; i < bound; i++) {
        FloatType yr = detail::ddot(
          n_, w, iypt + cp * n_, SizeType(1), w, SizeType(0), SizeType(1));
        FloatType beta = w[n_ + cp] * yr;
        SizeType inmc=n_+m_+cp;
        beta = w[inmc] - beta;
        SizeType iscn=ispt+cp*n_;
        detail::daxpy(n_, beta, w, iscn, w);
        cp++;
        if (cp == m_) cp = 0;
      }
      // Store the new search direction in the slot `point`.
      std::copy(w, w+n_, w+(ispt + point * n_));
    }
    // Initial step of the line search: stp1 in the first iteration,
    // otherwise 1.
    stp_ = FloatType(1);
    if (iter_ == 1) stp_ = stp1;
    // Keep the gradient at the start of the line search in w[0 .. n); it
    // is needed below to form the gradient difference.
    std::copy(g, g+n_, w);
  }
  // The scratch array doubles as the line search's work array; diag is not
  // read after this point, so overwriting it there is harmless.
  mcsrch_instance.run(
    gtol_, stpmin_, stpmax_, n_, x, f, g, w, ispt + point * n_,
    stp_, ftol, xtol_, maxfev_, info, nfev, &(*(scratch_array_.begin())));
  if (info == -1) {
    // The line search wrote a trial point into x and needs f and g there.
    iflag_ = 1;
    requests_f_and_g_ = true;
    return true;
  }
  if (info != 1) {
    throw error_internal_error(__FILE__, __LINE__);
  }
  // Line search finished: turn the stored direction into the actual step
  // (stp_ * direction) and store the gradient difference next to it.
  nfun_ += nfev;
  npt = point*n_;
  for (SizeType i = 0; i < n_; i++) {
    w[ispt + npt + i] = stp_ * w[ispt + npt + i];
    w[iypt + npt + i] = g[i] - w[i];
  }
  // Advance to the next slot, wrapping around after m_ slots.
  point++;
  if (point == m_) point = 0;
  return false;
}
01287
01289
01296 template <typename FloatType, typename SizeType = std::size_t>
01297 class traditional_convergence_test
01298 {
01299 public:
01301 traditional_convergence_test()
01302 : n_(0), eps_(0)
01303 {}
01304
01306
01312 explicit
01313 traditional_convergence_test(
01314 SizeType n,
01315 FloatType eps = FloatType(1.e-5))
01316 : n_(n), eps_(eps)
01317 {
01318 if (n_ == 0) {
01319 throw error_improper_input_parameter("n = 0.");
01320 }
01321 if (eps_ < FloatType(0)) {
01322 throw error_improper_input_parameter("eps < 0.");
01323 }
01324 }
01325
01327 SizeType n() const { return n_; }
01328
01332 FloatType eps() const { return eps_; }
01333
01335
01346 bool
01347 operator()(const FloatType* x, const FloatType* g) const
01348 {
01349 FloatType xnorm = std::sqrt(detail::ddot(n_, x, x));
01350 FloatType gnorm = std::sqrt(detail::ddot(n_, g, g));
01351 if (gnorm <= eps_ * std::max(FloatType(1), xnorm)) return true;
01352 return false;
01353 }
01354 protected:
01355 const SizeType n_;
01356 const FloatType eps_;
01357 };
01358
01359 }}
01360
01361 #endif // SCITBX_LBFGS_H