crossval.h

Go to the documentation of this file.
00001 // -*- C++ -*-
00002 #ifndef __LEMGA_CROSSVAL_H__
00003 #define __LEMGA_CROSSVAL_H__
00004 
00011 #include <assert.h>
00012 #include "shared_ptr.h"
00013 #include "learnmodel.h"
00014 
00015 namespace lemga {
00016 
00021 class CrossVal : public LearnModel {
00022 protected:
00023     bool fullset;                
00024     std::vector<pcLearnModel> lm;
00025     std::vector<REAL> err;       
00026     UINT n_rounds;               
00027     pLearnModel best_lm;         
00028     int best;                    
00029 
00030 
00031 
00032 public:
00033     CrossVal () : fullset(true), n_rounds(1), best(-1) {}
00034     CrossVal (const CrossVal&);
00035     const CrossVal& operator= (const CrossVal&);
00036 
00037     virtual CrossVal* create () const = 0;
00038     virtual CrossVal* clone () const = 0;
00039 
00041     void add_model (const LearnModel&);
00043     UINT size () const { assert(lm.size() == err.size()); return lm.size(); }
00045     const LearnModel& model (UINT n) const {
00046         assert(n < size() && lm[n] != 0); return *lm[n]; }
00047 
00049     UINT rounds () const { return n_rounds; }
00051     void set_rounds (UINT r) { assert(r > 0); n_rounds = r; }
00053     bool full_train () const { return fullset; }
00054     void set_full_train (bool f = true) { fullset = f; }
00055 
00056     virtual void set_train_data (const pDataSet&, const pDataWgt& = 0);
00057     virtual void train ();
00058     virtual void reset ();
00059     virtual Output operator() (const Input& x) const {
00060         assert(best >= 0 && best_lm != 0);
00061         return (*best_lm)(x); }
00062     virtual Output get_output (UINT i) const {
00063         assert(best >= 0 && best_lm != 0 && ptd == best_lm->train_data());
00064         return best_lm->get_output(i); }
00065     virtual REAL margin_norm () const {
00066         assert(best >= 0 && best_lm != 0);
00067         return best_lm->margin_norm(); }
00068     virtual REAL margin_of (const Input& x, const Output& y) const {
00069         assert(best >= 0 && best_lm != 0);
00070         return best_lm->margin_of(x, y); }
00071     virtual REAL margin (UINT i) const {
00072         assert(best >= 0 && best_lm != 0 && ptd == best_lm->train_data());
00073         return best_lm->margin(i); }
00074 
00076     REAL error (UINT n) const {
00077         assert(n < size() && err[n] >= 0); return err[n]; }
00079     const LearnModel& best_model () const {
00080         assert(best >= 0);
00081         return best_lm? *best_lm : *lm[best]; }
00082 
00083 protected:
00085     virtual std::vector<REAL> cv_round () const = 0;
00086 
00087     virtual bool serialize (std::ostream&, ver_list&) const;
00088     virtual bool unserialize (std::istream&, ver_list&, const id_t& = NIL_ID);
00089 };
00090 
00092 class vFoldCrossVal : public CrossVal {
00093 public:
00094     vFoldCrossVal (UINT v = 10, UINT r = 0) { set_folds(v, r); }
00095     explicit vFoldCrossVal (std::istream& is) { is >> *this; }
00096 
00097     virtual const id_t& id () const;
00098     virtual vFoldCrossVal* create () const { return new vFoldCrossVal(); }
00099     virtual vFoldCrossVal* clone () const { return new vFoldCrossVal(*this); }
00100 
00101     UINT folds () const { return n_folds; }
00103     void set_folds (UINT v, UINT r = 0) {
00104         assert(v > 1); n_folds = v;
00105         if (r > 0) set_rounds(r);
00106     }
00107 
00108 protected:
00109     UINT n_folds;
00110     virtual std::vector<REAL> cv_round () const;
00111 
00112     virtual bool serialize (std::ostream&, ver_list&) const;
00113     virtual bool unserialize (std::istream&, ver_list&, const id_t& = NIL_ID);
00114 };
00115 typedef vFoldCrossVal kFoldCrossVal;
00116 
00118 class HoldoutCrossVal : public CrossVal {
00119 public:
00120     HoldoutCrossVal (REAL p = 1.0/6, UINT r = 0) { set_holdout(p, r); }
00121     explicit HoldoutCrossVal (std::istream& is) { is >> *this; }
00122 
00123     virtual const id_t& id () const;
00124     virtual HoldoutCrossVal* create () const { return new HoldoutCrossVal(); }
00125     virtual HoldoutCrossVal* clone () const {
00126         return new HoldoutCrossVal(*this); }
00127 
00128     REAL holdout () const { return p_test; }
00130     void set_holdout (REAL p, UINT r = 0) {
00131         assert(p > 0 && p < 0.9); p_test = p;
00132         if (r > 0) set_rounds(r);
00133     }
00134 
00135 protected:
00136     REAL p_test;
00137     virtual std::vector<REAL> cv_round () const;
00138 
00139     virtual bool serialize (std::ostream&, ver_list&) const;
00140     virtual bool unserialize (std::istream&, ver_list&, const id_t& = NIL_ID);
00141 };
00142 
00143 } // namespace lemga
00144 
00145 #ifdef  __CROSSVAL_H__
00146 #warning "This header file may conflict with another `crossval.h' file."
00147 #endif
00148 #define __CROSSVAL_H__
00149 #endif

Generated on Wed Nov 8 08:15:20 2006 for LEMGA by doxygen 1.4.6