cgboost.h

Go to the documentation of this file.
00001 // -*- C++ -*-
00002 #ifndef __LEMGA_AGGREGATING_CGBOOST_H__
00003 #define __LEMGA_AGGREGATING_CGBOOST_H__
00004 
00011 #include "boosting.h"
00012 
00013 namespace lemga {
00014 
00015 struct _boost_cg;
00016 
00036 class CGBoost : public Boosting {
00037     std::vector<std::vector<REAL> > all_wgts;
00038     friend struct _boost_cg;
00039 
00040 protected:
00041     /* only valid within training */
00042     std::vector<REAL> ncd,  
00043         cgd;  
00044 
00045 public:
00046     explicit CGBoost (bool cvx = false, const cost::Cost& c = cost::_cost)
00047         : Boosting(cvx, c) {}
00048     CGBoost (const Boosting& s) : Boosting(s) {
00049         const std::vector<REAL>::const_iterator b = lm_wgt.begin();
00050         for (UINT i = 1; i <= lm_wgt.size(); ++i)
00051             all_wgts.push_back(std::vector<REAL>(b, b+i));
00052     }
00053     explicit CGBoost (std::istream& is) { is >> *this; }
00054 
00055     virtual const id_t& id () const;
00056     virtual CGBoost* create () const { return new CGBoost(); }
00057     virtual CGBoost* clone () const { return new CGBoost(*this); }
00058 
00059     virtual bool set_aggregation_size (UINT);
00060     virtual void train ();
00061     virtual void reset ();
00062 
00063 protected:
00065     std::vector<REAL> cur_err;
00066 
00067     virtual void train_gd ();
00068     virtual REAL linear_weight (const DataWgt&, const LearnModel&);
00069     virtual void linear_smpwgt (DataWgt&);
00070 
00071     virtual bool serialize (std::ostream&, ver_list&) const;
00072     virtual bool unserialize (std::istream&, ver_list&, const id_t& = NIL_ID);
00073 };
00074 
00075 struct _boost_cg : public _boost_gd {
00076     CGBoost* cg;
00077     _boost_cg (CGBoost* pc) : _boost_gd(pc), cg(pc) {}
00078 
00079     void set_weight (const Boosting::BoostWgt& bw) const {
00080         _boost_gd::set_weight(bw);
00081         assert(cg->n_in_agg == bw.size() && cg->n_in_agg == cg->lm_wgt.size());
00082 
00083         // save weights to all_wgts
00084         if (cg->n_in_agg == 0) return;
00085         const UINT n = cg->n_in_agg - 1;
00086         if (n < cg->all_wgts.size())
00087             cg->all_wgts[n] = cg->lm_wgt;
00088         else {
00089             assert(n == cg->all_wgts.size()); // allow size inc <= 1
00090             cg->all_wgts.push_back(cg->lm_wgt);
00091         }
00092     }
00093 };
00094 
00095 } // namespace lemga
00096 
00097 #ifdef  __CGBOOST_H__
00098 #warning "This header file may conflict with another `cgboost.h' file."
00099 #endif
00100 #define __CGBOOST_H__
00101 #endif

Generated on Wed Nov 8 08:15:20 2006 for LEMGA by doxygen 1.4.6