00001
00002 #ifndef __LEMGA_OPTIMIZE_H__
00003 #define __LEMGA_OPTIMIZE_H__
00004
00011 #include <assert.h>
00012 #include <iostream>
00013 #include <utility>
00014 #include "vectorop.h"
00015
00016 namespace lemga {
00017
00018 using namespace op;
00019
/** Skeleton of an iterative search: the interface consumed by
 *  iterative_optimize().  Concrete search strategies derive from this
 *  struct and replace the hooks below; the defaults form a search that
 *  terminates immediately. */
template <class Dir, class Step>
struct _search {
    typedef Dir direction_type;
    typedef Step step_length_type;

    /// Hook: prepare internal state before the first iteration.
    void initialize () {}
    /// Hook: compute and return the current search direction.
    const Dir& direction () { return dir; }
    /// Hook: choose a step length along @a d; first == false aborts.
    std::pair<bool,Step> step_length (const Dir& d) {
        return std::pair<bool,Step>(false, 0);
    }
    /// Hook: move the weight by step @a s along direction @a d.
    void update_weight (const Dir& d, const Step& s) {}
    /// Hook: true when the search should stop.
    bool satisfied () { return true; }
protected:
    Dir dir;   ///< storage for the current search direction
};
00071
/** Drive a search object through the generic optimization loop.
 *
 *  SEARCH must provide initialize(), satisfied(), direction(),
 *  step_length() and update_weight(), plus the nested types
 *  direction_type and step_length_type (see _search).
 *  The loop stops when the search reports satisfied(), or when
 *  step_length() declines to produce a step.
 *  @note @a s is taken by value, as in the original interface.
 */
template <class SEARCH>
void iterative_optimize (SEARCH s) {
    typedef typename SEARCH::direction_type dir_t;
    typedef typename SEARCH::step_length_type len_t;

    s.initialize();
    for (;;) {
        if (s.satisfied()) break;              // convergence test first
        const dir_t& d = s.direction();
        const std::pair<bool,len_t> len = s.step_length(d);
        if (!len.first) break;                 // no acceptable step found
        s.update_weight(d, len.second);
    }
}
00086
00092 template <class LM, class Dir, class Step>
00093 struct _gradient_descent : public _search<Dir,Step> {
00094 LM* plm;
00095 Step learning_rate;
00096
00097 _gradient_descent (LM* lm, const Step& lr)
00098 : _search<Dir,Step>(), plm(lm), learning_rate(lr) {}
00099
00100 void initialize () { stp_cnt = 0; w = plm->weight(); }
00101
00102 const Dir& direction () {
00103 using namespace op;
00104 return (this->dir = -plm->gradient());
00105 }
00106
00107 std::pair<bool,Step> step_length (const Dir&) {
00108 return std::make_pair(true, learning_rate);
00109 }
00110
00111 void update_weight (Dir d, const Step& s) {
00112 using namespace op;
00113 w += (d *= s);
00114 ++stp_cnt; plm->set_weight(w);
00115 }
00116
00117 bool satisfied () { return plm->stop_opt(stp_cnt, plm->cost()); }
00118
00119 protected:
00120 Dir w;
00121 unsigned int stp_cnt;
00122 };
00123
00139 template <class LM, class Dir, class Step>
00140 struct _gd_weightdecay : public _gradient_descent<LM,Dir,Step> {
00141 Step decay;
00142
00143 _gd_weightdecay (LM* lm, const Step& lr, const Step& dcy)
00144 : _gradient_descent<LM,Dir,Step>(lm,lr), decay(dcy) {}
00145
00146 void update_weight (Dir d, const Step& s) {
00147 using namespace op;
00148 assert(s*decay < 1);
00149 this->w *= (1 - s*decay);
00150 this->w += (d *= s);
00151 ++this->stp_cnt; this->plm->set_weight(this->w);
00152 }
00153 };
00154
00168 template <class LM, class Dir, class Step>
00169 struct _gd_momentum : public _gradient_descent<LM,Dir,Step> {
00170 Step momentum;
00171
00172 _gd_momentum (LM* lm, const Step& lr, const Step& m)
00173 : _gradient_descent<LM,Dir,Step>(lm,lr), momentum(m) {}
00174
00175 const Dir& direction () {
00176 assert(momentum >= 0 && momentum < 1);
00177 using namespace op;
00178 if (this->stp_cnt > 0) {
00179 this->dir *= momentum;
00180 this->dir += -this->plm->gradient();
00181 }
00182 else this->dir = -this->plm->gradient();
00183 return this->dir;
00184 }
00185 };
00186
/** Gradient descent with an adaptive ("bold driver") learning rate:
 *  after a successful step the rate is grown by @a alpha; after a
 *  failed trial it is shrunk by @a beta until the cost improves.
 *  @note step_length() itself commits trial weights to the model via
 *  set_weight(); update_weight() below therefore only re-synchronizes
 *  the cached weight @a w and does NOT call set_weight() again. */
template <class LM, class Dir, class Step, class Cost>
struct _gd_adaptive : public _gradient_descent<LM,Dir,Step> {
    using _gradient_descent<LM,Dir,Step>::plm;
    using _gradient_descent<LM,Dir,Step>::learning_rate;
    using _gradient_descent<LM,Dir,Step>::w;
    using _gradient_descent<LM,Dir,Step>::stp_cnt;
    Step alpha, beta;  // rate growth (>= 1) and shrink (< 1) factors

    _gd_adaptive (LM* lm, const Step& lr, const Step& a, const Step& b)
        : _gradient_descent<LM,Dir,Step>(lm,lr), alpha(a), beta(b) {}

    /// Also record the starting cost, so trials can be compared to it.
    void initialize () {
        _gradient_descent<LM,Dir,Step>::initialize();
        old_cost = plm->cost();
    }

    /** Try the current learning rate; on failure back off by beta until
     *  the cost improves or the rate drops to ~1e-6.  Returns the rate
     *  actually applied, paired with whether any improvement was found.
     *  @a d is taken by value so it can be scaled in place. */
    std::pair<bool,Step> step_length (Dir d) {
        assert(alpha >= 1 && beta < 1);
        using namespace op;

        // lr remembers the rate of the step actually taken; on the
        // success path below learning_rate is grown for NEXT time, so
        // the pre-growth value must be returned.
        Step lr = learning_rate;

        // Trial point: w + learning_rate * d (committed to the model).
        d *= learning_rate;
        Dir wd = w;
        plm->set_weight(wd += d);
        Cost c = plm->cost();
        if (c < old_cost)
            learning_rate *= alpha;    // success: be bolder next step
        else {
            // Failure: shrink the rate (and the scaled direction d with
            // it) until the cost improves or the rate floor is reached.
            do {
                learning_rate *= beta;
                d *= beta; wd = w;
                plm->set_weight(wd += d);
                c = plm->cost();
            } while (!(c < old_cost) && learning_rate > 1e-6);
            lr = learning_rate;        // rate of the last trial taken
        }

        const bool cont = (c < old_cost);
        if (cont) old_cost = c;        // accept: remember improved cost
        else plm->set_weight(w);       // give up: restore old weight
        return std::make_pair(cont, lr);
    }

    /** Only synchronize the cached weight; the model weight was already
     *  set by the accepted trial in step_length(). */
    void update_weight (Dir d, const Step& s) {
        ++stp_cnt;
        using namespace op;
        w += (d *= s);

    }

    /// Use the cached cost instead of re-evaluating the model.
    bool satisfied () { return plm->stop_opt(stp_cnt, old_cost); }

protected:
    Cost old_cost;  ///< cost at the currently accepted weight
};
00243
00244 namespace details {
00245
/** Line search along @a dir starting from weight @a w.
 *
 *  Three phases, each printing a progress character:
 *   '-' : halve the initial @a step until it actually descends;
 *   '*' : advance with growing increments until the cost rises again,
 *         bracketing a minimum at stp1 < stp3 < stp5 with cst3 lowest;
 *   '.' : evaluate the two mid-points of the bracket and shrink it
 *         until both gaps are within 1%.
 *
 *  @param lm    the model; must provide weight(), set_weight(), cost().
 *  @param w     the current weight (must equal lm.weight() on entry).
 *  @param cst3  out: cost at the returned step length.
 *  @param dir   search direction.
 *  @param step  initial trial step length.
 *  @return the chosen step length, or 0 if @a dir is not a descending
 *          direction (in which case the weight is restored to @a w).
 *  @note on a successful return the model's weight is whatever trial
 *        point was evaluated last — the caller is expected to set the
 *        final weight itself (cf. _gradient_descent::update_weight).
 */
template <class LM, class Dir, class Step, class Cost>
Step line_search (LM& lm, const Dir& w, Cost& cst3,
                  const Dir& dir, Step step) {
    using namespace op;
    assert(w == lm.weight());
    cst3 = lm.cost();       // cost at step 0
    Step stp3 = 0;

    // Phase 1: halve the step until the cost at w + step*dir does not
    // exceed the cost at w (or the step becomes negligibly small).
    Dir d = dir, wd = w; d *= step;
    lm.set_weight(wd += d); Cost cst5 = lm.cost();
    while (cst5 > cst3 && step > 2e-7) {
        std::cout << '-';
        step *= 0.5; d *= 0.5; wd = w;
        lm.set_weight(wd += d); cst5 = lm.cost();
    }

    if (cst5 > cst3) {
        std::cerr << "\tWarning: not a descending direction\n";
        lm.set_weight(w);   // restore the original weight
        return 0;
    }

    // Phase 2: march forward, doubling the increment, until the cost
    // increases; afterwards stp3 (cost cst3) is bracketed by stp1/stp5.
    Step stp1, stp5 = step;
    do {
        std::cout << '*';
        step += step;
        stp1 = stp3;
        stp3 = stp5; cst3 = cst5;
        stp5 += step;
        d = dir; d *= stp5; wd = w;
        lm.set_weight(wd += d); cst5 = lm.cost();
    } while (cst5 < cst3);

    // Phase 3: shrink the bracket by probing the two mid-points until
    // both sub-intervals are within 1% of each other.
    while (stp3 > stp1*1.01 || stp5 > stp3*1.01) {
        std::cout << '.';
        Step stp2 = (stp1 + stp3) / 2;
        Step stp4 = (stp3 + stp5) / 2;
        d = dir; d *= stp2; wd = w;
        lm.set_weight(wd += d); Cost cst2 = lm.cost();
        d = dir; d *= stp4; wd = w;
        lm.set_weight(wd += d); Cost cst4 = lm.cost();

        if (cst4 < cst2 && cst4 < cst3) {
            // upper mid-point is best: move the bracket right
            stp1 = stp3;
            stp3 = stp4; cst3 = cst4;
        }
        else if (cst2 < cst3) {
            // lower mid-point is best: move the bracket left
            stp5 = stp3;
            stp3 = stp2; cst3 = cst2;
        }
        else {
            // the center stays best: tighten both ends
            stp1 = stp2;
            stp5 = stp4;
        }
    }
    std::cout << "\tcost = " << cst3 << ", step = " << stp3 << '\n';
    return stp3;
}
00304
00305 }
00306
00307 template <class LM, class Dir, class Step, class Cost>
00308 struct _line_search : public _gradient_descent<LM,Dir,Step> {
00309 using _gradient_descent<LM,Dir,Step>::plm;
00310 using _gradient_descent<LM,Dir,Step>::learning_rate;
00311 using _gradient_descent<LM,Dir,Step>::w;
00312 using _gradient_descent<LM,Dir,Step>::stp_cnt;
00313
00314 _line_search (LM* lm, const Step& lr)
00315 : _gradient_descent<LM,Dir,Step>(lm,lr) {}
00316
00317 void initialize () {
00318 _gradient_descent<LM,Dir,Step>::initialize();
00319 cost_w = plm->cost();
00320 }
00321
00322 std::pair<bool,Step> step_length (const Dir& d) {
00323 const Step stp =
00324 details::line_search(*plm, w, cost_w, d, learning_rate);
00325 return std::make_pair((stp>0), stp);
00326 }
00327
00328 bool satisfied () { return plm->stop_opt(stp_cnt, cost_w); }
00329
00330 protected:
00331 Cost cost_w;
00332 };
00333
00334 template <class LM, class Dir, class Step, class Cost>
00335 struct _conjugate_gradient : public _line_search<LM,Dir,Step,Cost> {
00336 using _gradient_descent<LM,Dir,Step>::plm;
00337 using _gradient_descent<LM,Dir,Step>::w;
00338 using _gradient_descent<LM,Dir,Step>::stp_cnt;
00339 using _search<Dir,Step>::dir;
00340
00341 _conjugate_gradient (LM* lm, const Step& lr)
00342 : _line_search<LM,Dir,Step,Cost>(lm,lr) {}
00343
00344 const Dir& direction () {
00345
00346 const Dir g = plm->gradient();
00347 const Step g_norm = op::inner_product<Step>(g, g);
00348
00349 using namespace op;
00350
00351 if (stp_cnt == 0)
00352 dir = -g;
00353 else {
00354 const Step g_dot_old = op::inner_product<Step>(g, g_old);
00355 assert(g_norm_old > 0);
00356 Step beta = (g_norm - g_dot_old) / g_norm_old;
00357 if (beta < 0) beta = 0;
00358
00359 dir *= beta;
00360 dir += -g;
00361 }
00362
00363 g_old = g;
00364 g_norm_old = g_norm;
00365
00366 return dir;
00367 }
00368
00369 private:
00370 Dir g_old;
00371 Step g_norm_old;
00372 };
00373
00374 }
00375
00376 #ifdef __OPTIMIZE_H__
00377 #warning "This header file may conflict with another `optimize.h' file."
00378 #endif
00379 #define __OPTIMIZE_H__
00380 #endif