00001
00002 #ifndef __LEMGA_AGGREGATING_BOOSTING_H__
00003 #define __LEMGA_AGGREGATING_BOOSTING_H__
00004
00011 #include <numeric>
00012 #include <utility>
00013 #include "aggregating.h"
00014 #include "cost.h"
00015
00016 #define BOOSTING_OUTPUT_CACHE 1
00017
00018 namespace lemga {
00019
00021 struct _boost_gd;
00022
/** @brief Boosting: aggregation of hypotheses trained sequentially, each
 *  on a reweighted version of the training data.
 *
 *  Two training views are supported:
 *   - the classic view (train()): each new hypothesis receives a weight
 *     via assign_weight() and the sample weights are updated via
 *     update_smpwgt(), with convex/linear variants as hooks;
 *   - the gradient-descent view (train_gd(), enabled through
 *     use_gradient_descent()): minimizes cost_functor over the span of
 *     hypotheses; see the friend helper struct _boost_gd.
 */
class Boosting : public Aggregating {
protected:
    std::vector<REAL> lm_wgt;   //!< hypothesis weights (parallel to lm)
    bool convex;                //!< convex (sum-to-1) or linear combination?

    // --- options used by the gradient-descent view of training ---
    bool grad_desc_view;        //!< train via gradient descent (train_gd)?
    REAL min_cst, min_err;      //!< stop when cost / training error drops below

public:
    //! Cost functional minimized by the gradient-descent view.
    const cost::Cost& cost_functor;

    explicit Boosting (bool cvx = false, const cost::Cost& = cost::_cost);
    //! Construct by unserializing from a stream.
    //! NOTE(review): only cost_functor is initialized here; the remaining
    //! members are assumed to be set by unserialize() — confirm.
    explicit Boosting (std::istream& is): cost_functor(cost::_cost)
    { is >> *this; }

    virtual const id_t& id () const;
    virtual Boosting* create () const { return new Boosting(); }
    virtual Boosting* clone () const { return new Boosting(*this); }

    bool is_convex () const { return convex; }
    //! Weight of the n-th aggregated hypothesis (no bounds checking).
    REAL model_weight (UINT n) const { return lm_wgt[n]; }
    void use_gradient_descent (bool gd = true) { grad_desc_view = gd; }
    void set_min_cost (REAL mincst) { min_cst = mincst; }
    void set_min_error (REAL minerr) { min_err = minerr; }
    virtual REAL margin_norm () const;
    virtual REAL margin_of (const Input&, const Output&) const;
    virtual REAL margin (UINT) const;

    virtual bool support_weighted_data () const { return true; }
    virtual void initialize ();
    virtual REAL train ();
    virtual Output operator() (const Input&) const;
    virtual Output get_output (UINT) const;

#if BOOSTING_OUTPUT_CACHE
    //! Setting new training data must also invalidate the output cache.
    virtual void set_train_data (const pDataSet&, const pDataWgt& = 0);

private:
    // Per-sample cache of aggregated outputs on the training set:
    // cache_y[i] is the partial sum over the first cache_n[i] hypotheses.
    mutable std::vector<UINT> cache_n;
    mutable std::vector<Output> cache_y;
protected:
    //! Reset the cached output of sample @a idx to all zeros.
    inline void clear_cache (UINT idx) const {
        assert(idx < n_samples && cache_n.size() == n_samples);
        cache_n[idx] = 0;
        Output y(_n_out, 0);
        cache_y[idx].swap(y);   // swap releases the old buffer
    }
    //! Resize the cache to the current data set and reset every entry.
    inline void clear_cache () const {
        cache_n.resize(n_samples);
        cache_y.resize(n_samples);
        for (UINT i = 0; i < n_samples; ++i)
            clear_cache(i);
    }
#endif

protected:
    //! Sum of the weights of the hypotheses currently aggregated.
    REAL model_weight_sum () const {
        return std::accumulate
            (lm_wgt.begin(), lm_wgt.begin()+n_in_agg, REAL(0));
    }

    //! Train a (copy of the) base model under the given sample weights.
    pLearnModel train_with_smpwgt (const pDataWgt&) const;

    //! Weight for the newly trained hypothesis @a l under sample weights
    //! @a sw, dispatching on the convex/linear setting.
    REAL assign_weight (const DataWgt& sw, const LearnModel& l) {
        assert(n_samples == sw.size());
        return convex? convex_weight(sw, l) : linear_weight(sw, l);
    }

    //! Compute updated sample weights after hypothesis @a l was added;
    //! @a l must be the most recently aggregated model.
    pDataWgt update_smpwgt (const DataWgt& sw, const LearnModel& l) {
        assert(n_in_agg > 0 && lm[n_in_agg-1] == &l);
        assert(n_samples == sw.size());
        DataWgt* pdw = new DataWgt(sw);   // copy, then update in place
        convex? convex_smpwgt(*pdw) : linear_smpwgt(*pdw);
        return pdw;
    }

    // Hooks implementing the convex / linear boosting variants.
    virtual REAL convex_weight (const DataWgt&, const LearnModel&);
    virtual REAL linear_weight (const DataWgt&, const LearnModel&);
    virtual void convex_smpwgt (DataWgt&);
    virtual void linear_smpwgt (DataWgt&);

protected:
    virtual bool serialize (std::ostream&, ver_list&) const;
    virtual bool unserialize (std::istream&, ver_list&,
                              const id_t& = empty_id);

public:
    friend struct _boost_gd;
    //! Cost of the current aggregation on the training set.
    REAL cost () const;

protected:
    //! Training by gradient descent in the space of hypotheses.
    virtual REAL train_gd ();
    //! Sample weights derived from the cost gradient.
    pDataWgt sample_weight () const;

public:
    /** @brief The "weight vector" manipulated by the gradient-descent
     *  view: a list of hypotheses paired with their coefficients. */
    class BoostWgt {
        std::vector<pLearnModel> lm;
        std::vector<REAL> lm_wgt;

    public:
        BoostWgt () {}
        BoostWgt (const std::vector<pLearnModel>& _lm,
                  const std::vector<REAL>& _wgt)
            : lm(_lm), lm_wgt(_wgt) { assert(lm.size() == lm_wgt.size()); }

        UINT size () const { return lm.size(); }
        const std::vector<pLearnModel>& models () const { return lm; }
        const std::vector<REAL>& weights () const { return lm_wgt; }
        void clear () { lm.clear(); lm_wgt.clear(); }

        BoostWgt& operator+= (const BoostWgt&);
        BoostWgt& operator*= (REAL);
        BoostWgt operator- () const;
#ifndef NDEBUG
        //! Debug-only equality used by optimization sanity checks.
        bool operator== (const BoostWgt& w) const {
            return (lm == w.lm && lm_wgt == w.lm_wgt);
        }
#endif
    };
};
00231
00232 struct _boost_gd {
00233 Boosting* b;
00234 explicit _boost_gd (Boosting* pb) : b(pb) {}
00235
00236 REAL cost () const { return b->cost(); }
00237
00238 Boosting::BoostWgt weight () const {
00239 return Boosting::BoostWgt(b->lm, b->lm_wgt);
00240 }
00241
00242 void set_weight (const Boosting::BoostWgt& bw) const {
00243 #if BOOSTING_OUTPUT_CACHE
00244 b->clear_cache();
00245 #endif
00246 b->lm = bw.models(); b->lm_wgt = bw.weights();
00247 b->n_in_agg = b->lm.size();
00248 assert(b->lm.size() == b->lm_wgt.size());
00249 }
00250
00251 Boosting::BoostWgt gradient () const {
00252 std::vector<pLearnModel> lm = b->lm;
00253 std::vector<REAL> wgt(lm.size(), 0);
00254
00255 lm.push_back(b->train_with_smpwgt(b->sample_weight()));
00256 wgt.push_back(-1);
00257 return Boosting::BoostWgt(lm, wgt);
00258 }
00259
00260 bool stop_opt (UINT step, REAL cst) const {
00261 return (step >= b->max_n_model || cst < b->min_cst);
00262 }
00263 };
00264
00265 namespace op {
00266
00267 template <typename R>
00268 R inner_product (const Boosting::BoostWgt& w1, const Boosting::BoostWgt& w2) {
00269 #ifndef NDEBUG
00270 std::vector<REAL> w1t(w1.size()); w1t.back() = -1;
00271 assert(w1.weights() == w1t);
00272 std::vector<REAL> w2t(w2.size()); w2t.back() = -1;
00273 assert(w2.weights() == w2t);
00274 #endif
00275 LearnModel& g1 = *w1.models().back();
00276 LearnModel& g2 = *w2.models().back();
00277
00278 UINT n = g1.train_data()->size();
00279 assert(n == g2.train_data()->size());
00280 R sum = 0;
00281 for (UINT i = 0; i < n; ++i)
00282 sum += g1.get_output(i)[0] * g2.get_output(i)[0];
00283 return sum / n;
00284 }
00285
00286 }
00287 }
00288
00289 #ifdef __BOOSTING_H__
00290 #warning "This header file may conflict with another `boosting.h' file."
00291 #endif
00292 #define __BOOSTING_H__
00293 #endif