cgboost.h

Go to the documentation of this file.
00001 // -*- C++ -*-
00002 #ifndef __LEMGA_AGGREGATING_CGBOOST_H__
00003 #define __LEMGA_AGGREGATING_CGBOOST_H__
00004 
00011 #include "boosting.h"
00012 
00013 namespace lemga {
00014 
00015 struct _boost_cg;
00016 
00036 class CGBoost : public Boosting {
00037     std::vector<std::vector<REAL> > all_wgts;
00038     friend struct _boost_cg;
00039 
00040 protected:
00041     /* only valid within training */
00042     std::vector<REAL> ncd,  
00043         cgd;  
00044 
00045 public:
00046     explicit CGBoost (bool cvx = false, const cost::Cost& c = cost::_cost)
00047         : Boosting(cvx, c) {}
00048     explicit CGBoost (std::istream& is) { is >> *this; }
00049 
00050     virtual const id_t& id () const;
00051     virtual CGBoost* create () const { return new CGBoost(); }
00052     virtual CGBoost* clone () const { return new CGBoost(*this); }
00053 
00054     virtual bool set_aggregation_size (UINT);
00055     virtual void initialize ();
00056     virtual REAL train ();
00057     virtual REAL train_gd ();
00058 
00059 protected:
00061     std::vector<REAL> cur_err;
00062 
00063     virtual REAL linear_weight (const DataWgt&, const LearnModel&);
00064     virtual void linear_smpwgt (DataWgt&);
00065 
00066     virtual bool serialize (std::ostream&, ver_list&) const;
00067     virtual bool unserialize (std::istream&, ver_list&,
00068                               const id_t& = empty_id);
00069 };
00070 
00071 struct _boost_cg : public _boost_gd {
00072     CGBoost* cg;
00073     _boost_cg (CGBoost* pc) : _boost_gd(pc), cg(pc) {}
00074 
00075     void set_weight (const Boosting::BoostWgt& bw) const {
00076         _boost_gd::set_weight(bw);
00077         assert(cg->n_in_agg == bw.size() && cg->n_in_agg == cg->lm_wgt.size());
00078 
00079         // save weights to all_wgts
00080         if (cg->n_in_agg == 0) return;
00081         const UINT n = cg->n_in_agg - 1;
00082         if (n < cg->all_wgts.size())
00083             cg->all_wgts[n] = cg->lm_wgt;
00084         else {
00085             assert(n == cg->all_wgts.size()); // allow size inc <= 1
00086             cg->all_wgts.push_back(cg->lm_wgt);
00087         }
00088     }
00089 };
00090 
00091 } // namespace lemga
00092 
00093 #ifdef  __CGBOOST_H__
00094 #warning "This header file may conflict with another `cgboost.h' file."
00095 #endif
00096 #define __CGBOOST_H__
00097 #endif

Generated on Mon Jan 9 23:43:23 2006 for LEMGA by doxygen 1.4.6