00001 // -*- C++ -*- 00002 #ifndef __LEMGA_AGGREGATING_CGBOOST_H__ 00003 #define __LEMGA_AGGREGATING_CGBOOST_H__ 00004 00011 #include "boosting.h" 00012 00013 namespace lemga { 00014 00015 struct _boost_cg; 00016 00036 class CGBoost : public Boosting { 00037 std::vector<std::vector<REAL> > all_wgts; 00038 friend struct _boost_cg; 00039 00040 protected: 00041 /* only valid within training */ 00042 std::vector<REAL> ncd, 00043 cgd; 00044 00045 public: 00046 explicit CGBoost (bool cvx = false, const cost::Cost& c = cost::_cost) 00047 : Boosting(cvx, c) {} 00048 explicit CGBoost (std::istream& is) { is >> *this; } 00049 00050 virtual const id_t& id () const; 00051 virtual CGBoost* create () const { return new CGBoost(); } 00052 virtual CGBoost* clone () const { return new CGBoost(*this); } 00053 00054 virtual bool set_aggregation_size (UINT); 00055 virtual void initialize (); 00056 virtual REAL train (); 00057 virtual REAL train_gd (); 00058 00059 protected: 00061 std::vector<REAL> cur_err; 00062 00063 virtual REAL linear_weight (const DataWgt&, const LearnModel&); 00064 virtual void linear_smpwgt (DataWgt&); 00065 00066 virtual bool serialize (std::ostream&, ver_list&) const; 00067 virtual bool unserialize (std::istream&, ver_list&, 00068 const id_t& = empty_id); 00069 }; 00070 00071 struct _boost_cg : public _boost_gd { 00072 CGBoost* cg; 00073 _boost_cg (CGBoost* pc) : _boost_gd(pc), cg(pc) {} 00074 00075 void set_weight (const Boosting::BoostWgt& bw) const { 00076 _boost_gd::set_weight(bw); 00077 assert(cg->n_in_agg == bw.size() && cg->n_in_agg == cg->lm_wgt.size()); 00078 00079 // save weights to all_wgts 00080 if (cg->n_in_agg == 0) return; 00081 const UINT n = cg->n_in_agg - 1; 00082 if (n < cg->all_wgts.size()) 00083 cg->all_wgts[n] = cg->lm_wgt; 00084 else { 00085 assert(n == cg->all_wgts.size()); // allow size inc <= 1 00086 cg->all_wgts.push_back(cg->lm_wgt); 00087 } 00088 } 00089 }; 00090 00091 } // namespace lemga 00092 00093 #ifdef __CGBOOST_H__ 00094 #warning "This header file may conflict with another `cgboost.h' file." 00095 #endif 00096 #define __CGBOOST_H__ 00097 #endif