00001
00002 #ifndef __LEMGA_AGGREGATING_BOOSTING_H__
00003 #define __LEMGA_AGGREGATING_BOOSTING_H__
00004
00011 #include <numeric>
00012 #include <utility>
00013 #include "aggregating.h"
00014 #include "cost.h"
00015
00016 #define BOOSTING_OUTPUT_CACHE 1
00017
00018 namespace lemga {
00019
00021 struct _boost_gd;
00022
/** Boosting aggregation: a weighted combination of base hypotheses.
 *
 *  Each base model held by the Aggregating parent gets one REAL weight
 *  in lm_wgt; the combination is either "convex" or "linear" (flag set
 *  at construction).  Training may alternatively be carried out through
 *  a gradient-descent view of boosting, in which case cost_functor is
 *  the cost functional involved.
 */
class Boosting : public Aggregating {
protected:
    std::vector<REAL> lm_wgt;   //!< hypothesis weights, parallel to the model list
    bool convex;                //!< convex (true) or linear (false) combination

    // options used when training is done via the gradient-descent view
    bool grad_desc_view;        //!< use train_gd() instead of classic boosting?
    REAL min_cst, min_err;      //!< stopping thresholds (cost / training error)

public:
    //! cost functional; reference bound for the object's lifetime at construction
    const cost::Cost& cost_functor;

    explicit Boosting (bool cvx = false, const cost::Cost& = cost::_cost);
    Boosting (const Aggregating&);
    //! unserialize from a stream; the cost functor defaults to cost::_cost
    explicit Boosting (std::istream& is): cost_functor(cost::_cost)
    { is >> *this; }

    virtual const id_t& id () const;
    virtual Boosting* create () const { return new Boosting(); }
    virtual Boosting* clone () const { return new Boosting(*this); }

    bool is_convex () const { return convex; }
    //! weight of the n-th base model (no bounds check)
    REAL model_weight (UINT n) const { return lm_wgt[n]; }
    void use_gradient_descent (bool gd = true) { grad_desc_view = gd; }
    void set_min_cost (REAL mincst) { min_cst = mincst; }
    void set_min_error (REAL minerr) { min_err = minerr; }
    virtual REAL margin_norm () const;
    virtual REAL margin_of (const Input&, const Output&) const;
    virtual REAL margin (UINT) const;

    virtual bool support_weighted_data () const { return true; }
    virtual void train ();
    virtual void reset ();
    virtual Output operator() (const Input&) const;
    virtual Output get_output (UINT) const;

#if BOOSTING_OUTPUT_CACHE
    virtual void set_train_data (const pDataSet&, const pDataWgt& = 0);

private:
    // Per-training-sample output cache.  cache_n[i] == 0 marks an empty
    // entry; presumably it counts how many leading models have been
    // accumulated into cache_y[i] — TODO confirm against the .cpp file.
    mutable std::vector<UINT> cache_n;
    mutable std::vector<Output> cache_y;
protected:
    //! invalidate the cached output of one sample
    inline void clear_cache (UINT idx) const {
        assert(idx < n_samples && cache_n.size() == n_samples);
        cache_n[idx] = 0;
        cache_y[idx].clear();
    }
    //! invalidate the whole cache and re-size it to n_samples empty entries
    inline void clear_cache () const {
        cache_n.clear(); cache_y.clear();
        cache_n.resize(n_samples, 0);
        cache_y.resize(n_samples);
    }
#endif

protected:
    //! sum of the weights of the models currently aggregated (first n_in_agg)
    REAL model_weight_sum () const {
        return std::accumulate
            (lm_wgt.begin(), lm_wgt.begin()+n_in_agg, REAL(0));
    }

    //! train a base model under the given per-sample weights
    pLearnModel train_with_smpwgt (const pDataWgt&) const;

    /** Aggregation weight for model @a l given sample weights @a sw;
     *  dispatches to the convex or linear variant. */
    REAL assign_weight (const DataWgt& sw, const LearnModel& l) {
        assert(n_samples == sw.size());
        return convex? convex_weight(sw, l) : linear_weight(sw, l);
    }

    /** Next round's sample weights, derived in place from a copy of the
     *  current ones @a sw.  @a l must already be the last model in the
     *  aggregation (assert below).
     *  NOTE(review): assumes pDataWgt takes ownership of the raw pointer. */
    pDataWgt update_smpwgt (const DataWgt& sw, const LearnModel& l) {
        assert(n_in_agg > 0 && lm[n_in_agg-1] == &l);
        assert(n_samples == sw.size());
        DataWgt* pdw = new DataWgt(sw);   // a copy, updated in place below
        convex? convex_smpwgt(*pdw) : linear_smpwgt(*pdw);
        return pdw;
    }

    // Convex / linear flavors of the two steps above; overridable by
    // descendants implementing concrete boosting algorithms.
    virtual REAL convex_weight (const DataWgt&, const LearnModel&);
    virtual REAL linear_weight (const DataWgt&, const LearnModel&);
    virtual void convex_smpwgt (DataWgt&);
    virtual void linear_smpwgt (DataWgt&);

protected:
    virtual bool serialize (std::ostream&, ver_list&) const;
    virtual bool unserialize (std::istream&, ver_list&, const id_t& = NIL_ID);

public:
    friend struct _boost_gd;   // the gradient-descent adapter pokes at internals
    //! current value of the cost functional
    REAL cost () const;

protected:
    //! training driven by gradient descent (used when grad_desc_view is set)
    virtual void train_gd ();
    //! sample-weight distribution derived from the current ensemble
    pDataWgt sample_weight () const;

public:
    /** The "point" moved around by gradient descent: a list of base
     *  models together with their (equally many) weights. */
    class BoostWgt {
        std::vector<pLearnModel> lm;
        std::vector<REAL> lm_wgt;

    public:
        BoostWgt () {}
        BoostWgt (const std::vector<pLearnModel>& _lm,
                  const std::vector<REAL>& _wgt)
            : lm(_lm), lm_wgt(_wgt) { assert(lm.size() == lm_wgt.size()); }

        UINT size () const { return lm.size(); }
        const std::vector<pLearnModel>& models () const { return lm; }
        const std::vector<REAL>& weights () const { return lm_wgt; }
        void clear () { lm.clear(); lm_wgt.clear(); }

        BoostWgt& operator+= (const BoostWgt&);
        BoostWgt& operator*= (REAL);
        BoostWgt operator- () const;
#ifndef NDEBUG
        //! debug-only structural equality
        bool operator== (const BoostWgt& w) const {
            return (lm == w.lm && lm_wgt == w.lm_wgt);
        }
#endif
    };
};
00229
00230 struct _boost_gd {
00231 Boosting* b;
00232 UINT max_step;
00233 explicit _boost_gd (Boosting* pb) : b(pb)
00234 { max_step = b->max_n_model - b->size(); }
00235
00236 REAL cost () const { return b->cost(); }
00237
00238 Boosting::BoostWgt weight () const {
00239 return Boosting::BoostWgt(b->lm, b->lm_wgt);
00240 }
00241
00242 void set_weight (const Boosting::BoostWgt& bw) const {
00243 #if BOOSTING_OUTPUT_CACHE
00244 b->clear_cache();
00245 #endif
00246 b->lm = bw.models(); b->lm_wgt = bw.weights();
00247 b->n_in_agg = b->lm.size();
00248 assert(b->lm.size() == b->lm_wgt.size());
00249 for (UINT i = 0; i < b->lm.size(); ++i)
00250 b->set_dimensions(*(b->lm[i]));
00251 }
00252
00253 Boosting::BoostWgt gradient () const {
00254 std::vector<pLearnModel> lm = b->lm;
00255 std::vector<REAL> wgt(lm.size(), 0);
00256
00257 lm.push_back(b->train_with_smpwgt(b->sample_weight()));
00258 wgt.push_back(-1);
00259 return Boosting::BoostWgt(lm, wgt);
00260 }
00261
00262 bool stop_opt (UINT step, REAL cst) const {
00263 return (step >= max_step || cst < b->min_cst);
00264 }
00265 };
00266
00267 namespace op {
00268
00269 template <typename R>
00270 R inner_product (const Boosting::BoostWgt& w1, const Boosting::BoostWgt& w2) {
00271 #ifndef NDEBUG
00272 std::vector<REAL> w1t(w1.size()); w1t.back() = -1;
00273 assert(w1.weights() == w1t);
00274 std::vector<REAL> w2t(w2.size()); w2t.back() = -1;
00275 assert(w2.weights() == w2t);
00276 #endif
00277 LearnModel& g1 = *w1.models().back();
00278 LearnModel& g2 = *w2.models().back();
00279
00280 UINT n = g1.train_data()->size();
00281 assert(g1.train_data() == g2.train_data());
00282 R sum = 0;
00283 for (UINT i = 0; i < n; ++i)
00284 sum += g1.get_output(i)[0] * g2.get_output(i)[0];
00285 return sum / n;
00286 }
00287
00288 }
00289 }
00290
00291 #ifdef __BOOSTING_H__
00292 #warning "This header file may conflict with another `boosting.h' file."
00293 #endif
00294 #define __BOOSTING_H__
00295 #endif