00001 // -*- C++ -*- 00002 #ifndef __LEMGA_AGGREGATING_CGBOOST_H__ 00003 #define __LEMGA_AGGREGATING_CGBOOST_H__ 00004 00011 #include "boosting.h" 00012 00013 namespace lemga { 00014 00015 struct _boost_cg; 00016 00036 class CGBoost : public Boosting { 00037 std::vector<std::vector<REAL> > all_wgts; 00038 friend struct _boost_cg; 00039 00040 protected: 00041 /* only valid within training */ 00042 std::vector<REAL> ncd, 00043 cgd; 00044 00045 public: 00046 explicit CGBoost (bool cvx = false, const cost::Cost& c = cost::_cost) 00047 : Boosting(cvx, c) {} 00048 CGBoost (const Boosting& s) : Boosting(s) { 00049 const std::vector<REAL>::const_iterator b = lm_wgt.begin(); 00050 for (UINT i = 1; i <= lm_wgt.size(); ++i) 00051 all_wgts.push_back(std::vector<REAL>(b, b+i)); 00052 } 00053 explicit CGBoost (std::istream& is) { is >> *this; } 00054 00055 virtual const id_t& id () const; 00056 virtual CGBoost* create () const { return new CGBoost(); } 00057 virtual CGBoost* clone () const { return new CGBoost(*this); } 00058 00059 virtual bool set_aggregation_size (UINT); 00060 virtual void train (); 00061 virtual void reset (); 00062 00063 protected: 00065 std::vector<REAL> cur_err; 00066 00067 virtual void train_gd (); 00068 virtual REAL linear_weight (const DataWgt&, const LearnModel&); 00069 virtual void linear_smpwgt (DataWgt&); 00070 00071 virtual bool serialize (std::ostream&, ver_list&) const; 00072 virtual bool unserialize (std::istream&, ver_list&, const id_t& = NIL_ID); 00073 }; 00074 00075 struct _boost_cg : public _boost_gd { 00076 CGBoost* cg; 00077 _boost_cg (CGBoost* pc) : _boost_gd(pc), cg(pc) {} 00078 00079 void set_weight (const Boosting::BoostWgt& bw) const { 00080 _boost_gd::set_weight(bw); 00081 assert(cg->n_in_agg == bw.size() && cg->n_in_agg == cg->lm_wgt.size()); 00082 00083 // save weights to all_wgts 00084 if (cg->n_in_agg == 0) return; 00085 const UINT n = cg->n_in_agg - 1; 00086 if (n < cg->all_wgts.size()) 00087 cg->all_wgts[n] = cg->lm_wgt; 00088 else { 00089 assert(n == cg->all_wgts.size()); // allow size inc <= 1 00090 cg->all_wgts.push_back(cg->lm_wgt); 00091 } 00092 } 00093 }; 00094 00095 } // namespace lemga 00096 00097 #ifdef __CGBOOST_H__ 00098 #warning "This header file may conflict with another `cgboost.h' file." 00099 #endif 00100 #define __CGBOOST_H__ 00101 #endif