00001
00002 #ifndef __LEMGA_CROSSVAL_H__
00003 #define __LEMGA_CROSSVAL_H__
00004
00011 #include <assert.h>
00012 #include "shared_ptr.h"
00013 #include "learnmodel.h"
00014
00015 namespace lemga {
00016
/// Abstract cross-validation model selector that is itself a LearnModel.
///
/// Keeps a list of candidate models (`lm`) with one error value per candidate
/// (`err`). Prediction and margin queries are forwarded to `best_lm`.
/// Subclasses must supply create(), clone() and cv_round(). The copy
/// constructor, operator=, add_model(), set_train_data(), train(), reset(),
/// serialize() and unserialize() are only declared here; their definitions
/// live outside this header.
///
/// All precondition checks are assert()s from <assert.h>, so they disappear
/// when NDEBUG is defined.
00021 class CrossVal : public LearnModel {
00022 protected:
// Flag read/written through full_train()/set_full_train(); starts as true.
// NOTE(review): presumably selects whether the chosen model is retrained on
// the full training set -- confirm in train().
00023 bool fullset;
// Candidate models, indexed in parallel with `err`.
00024 std::vector<pcLearnModel> lm;
// One error value per candidate; size() asserts err.size() == lm.size().
00025 std::vector<REAL> err;
// Number of rounds; starts at 1 and set_rounds() asserts any new value is > 0.
00026 UINT n_rounds;
// Model that operator(), get_output() and the margin queries delegate to.
00027 pLearnModel best_lm;
// Index into `lm`; initialised to -1, and the accessors assert best >= 0.
00028 int best;
00029
00030
00031
00032 public:
// Default state: fullset = true, one round, best = -1 (best_lm is
// default-constructed).
00033 CrossVal () : fullset(true), n_rounds(1), best(-1) {}
00034 CrossVal (const CrossVal&);
// NOTE: returns a const reference rather than the conventional CrossVal&.
00035 const CrossVal& operator= (const CrossVal&);
00036
// Pure virtual here; each concrete subclass returns its own type.
00037 virtual CrossVal* create () const = 0;
00038 virtual CrossVal* clone () const = 0;
00039
// Defined outside this header. NOTE(review): presumably appends a candidate
// to `lm` (and a matching slot to `err`) -- confirm in the implementation.
00041 void add_model (const LearnModel&);
/// @return number of candidate models; asserts `lm` and `err` agree in size.
00043 UINT size () const { assert(lm.size() == err.size()); return lm.size(); }
/// @return the n-th candidate; asserts n < size() and that the entry is set.
00045 const LearnModel& model (UINT n) const {
00046 assert(n < size() && lm[n] != 0); return *lm[n]; }
00047
/// @return the configured number of rounds.
00049 UINT rounds () const { return n_rounds; }
/// Sets the number of rounds; r must be positive (asserted).
00051 void set_rounds (UINT r) { assert(r > 0); n_rounds = r; }
/// @return the current value of the `fullset` flag.
00053 bool full_train () const { return fullset; }
/// Sets the `fullset` flag; calling without an argument sets it to true.
00054 void set_full_train (bool f = true) { fullset = f; }
00055
// Virtual training interface; defined outside this header.
00056 virtual void set_train_data (const pDataSet&, const pDataWgt& = 0);
00057 virtual void train ();
00058 virtual void reset ();
// Every query below asserts best >= 0 and a non-null best_lm, then forwards
// the call to best_lm unchanged.
/// @return best_lm's output for input x.
00059 virtual Output operator() (const Input& x) const {
00060 assert(best >= 0 && best_lm != 0);
00061 return (*best_lm)(x); }
// Index-based query: additionally asserts that best_lm->train_data() equals
// `ptd` (not declared in this header; presumably inherited from LearnModel).
/// @return best_lm's output for the i-th training example.
00062 virtual Output get_output (UINT i) const {
00063 assert(best >= 0 && best_lm != 0 && ptd == best_lm->train_data());
00064 return best_lm->get_output(i); }
/// @return best_lm's margin_norm().
00065 virtual REAL margin_norm () const {
00066 assert(best >= 0 && best_lm != 0);
00067 return best_lm->margin_norm(); }
/// @return best_lm's margin for the pair (x, y).
00068 virtual REAL margin_of (const Input& x, const Output& y) const {
00069 assert(best >= 0 && best_lm != 0);
00070 return best_lm->margin_of(x, y); }
// Index-based query: same `ptd == best_lm->train_data()` assertion as
// get_output().
/// @return best_lm's margin for the i-th training example.
00071 virtual REAL margin (UINT i) const {
00072 assert(best >= 0 && best_lm != 0 && ptd == best_lm->train_data());
00073 return best_lm->margin(i); }
00074
/// @return err[n]; asserts n < size() and that the stored value is >= 0.
00076 REAL error (UINT n) const {
00077 assert(n < size() && err[n] >= 0); return err[n]; }
/// @return *best_lm when best_lm is set, otherwise the candidate lm[best];
///         asserts best >= 0 (no upper-bound check on `best` is made here).
00079 const LearnModel& best_model () const {
00080 assert(best >= 0);
00081 return best_lm? *best_lm : *lm[best]; }
00082
00083 protected:
// Supplied by subclasses. NOTE(review): the returned vector presumably holds
// one error per candidate for a single round -- confirm against the
// implementations.
00085 virtual std::vector<REAL> cv_round () const = 0;
00086
// Serialization hooks; defined outside this header.
00087 virtual bool serialize (std::ostream&, ver_list&) const;
00088 virtual bool unserialize (std::istream&, ver_list&, const id_t& = NIL_ID);
00089 };
00090
/// Concrete CrossVal parameterised by a fold count (`n_folds`).
///
/// id(), cv_round(), serialize() and unserialize() are only declared here;
/// their definitions live outside this header.
/// NOTE(review): the name suggests v-fold cross-validation -- the actual
/// procedure is in the cv_round() implementation, not visible here.
00092 class vFoldCrossVal : public CrossVal {
00093 public:
/// @param v number of folds (default 10); set_folds() asserts v > 1.
/// @param r number of rounds; 0 leaves the value set by the CrossVal default
///          constructor (1) untouched.
// NOTE: not `explicit`, so a single UINT converts implicitly to this type.
00094 vFoldCrossVal (UINT v = 10, UINT r = 0) { set_folds(v, r); }
/// Constructs the object by reading it from `is` with operator>> (declared
/// elsewhere). n_folds is not initialised before that read takes place.
00095 explicit vFoldCrossVal (std::istream& is) { is >> *this; }
00096
// Defined outside this header.
00097 virtual const id_t& id () const;
/// @return a default-constructed vFoldCrossVal allocated with `new`
///         (raw pointer; presumably the caller takes ownership).
00098 virtual vFoldCrossVal* create () const { return new vFoldCrossVal(); }
/// @return a copy of *this allocated with `new` (raw pointer; presumably the
///         caller takes ownership).
00099 virtual vFoldCrossVal* clone () const { return new vFoldCrossVal(*this); }
00100
/// @return the configured number of folds.
00101 UINT folds () const { return n_folds; }
/// Sets the fold count and, optionally, the round count.
/// @param v number of folds; must be > 1 (asserted).
/// @param r number of rounds; only applied when r > 0, otherwise the current
///          round count is kept.
00103 void set_folds (UINT v, UINT r = 0) {
00104 assert(v > 1); n_folds = v;
00105 if (r > 0) set_rounds(r);
00106 }
00107
00108 protected:
// Fold count; written by set_folds().
00109 UINT n_folds;
// Overrides the pure virtual CrossVal::cv_round(); defined outside this header.
00110 virtual std::vector<REAL> cv_round () const;
00111
// Serialization hooks; defined outside this header.
00112 virtual bool serialize (std::ostream&, ver_list&) const;
00113 virtual bool unserialize (std::istream&, ver_list&, const id_t& = NIL_ID);
00114 };
// kFoldCrossVal is simply another name for vFoldCrossVal.
00115 typedef vFoldCrossVal kFoldCrossVal;
00116
/// Concrete CrossVal parameterised by a hold-out value (`p_test`).
///
/// id(), cv_round(), serialize() and unserialize() are only declared here;
/// their definitions live outside this header.
/// NOTE(review): `p_test` is presumably the fraction of the data held out for
/// testing in each round -- confirm in the cv_round() implementation.
00118 class HoldoutCrossVal : public CrossVal {
00119 public:
/// @param p hold-out value (default 1.0/6); set_holdout() asserts 0 < p < 0.9.
/// @param r number of rounds; 0 leaves the value set by the CrossVal default
///          constructor (1) untouched.
// NOTE: not `explicit`, so a single REAL converts implicitly to this type.
00120 HoldoutCrossVal (REAL p = 1.0/6, UINT r = 0) { set_holdout(p, r); }
/// Constructs the object by reading it from `is` with operator>> (declared
/// elsewhere). p_test is not initialised before that read takes place.
00121 explicit HoldoutCrossVal (std::istream& is) { is >> *this; }
00122
// Defined outside this header.
00123 virtual const id_t& id () const;
/// @return a default-constructed HoldoutCrossVal allocated with `new`
///         (raw pointer; presumably the caller takes ownership).
00124 virtual HoldoutCrossVal* create () const { return new HoldoutCrossVal(); }
/// @return a copy of *this allocated with `new` (raw pointer; presumably the
///         caller takes ownership).
00125 virtual HoldoutCrossVal* clone () const {
00126 return new HoldoutCrossVal(*this); }
00127
/// @return the configured hold-out value.
00128 REAL holdout () const { return p_test; }
/// Sets the hold-out value and, optionally, the round count.
/// @param p hold-out value; must satisfy 0 < p < 0.9 (asserted).
/// @param r number of rounds; only applied when r > 0, otherwise the current
///          round count is kept.
00130 void set_holdout (REAL p, UINT r = 0) {
00131 assert(p > 0 && p < 0.9); p_test = p;
00132 if (r > 0) set_rounds(r);
00133 }
00134
00135 protected:
// Hold-out value; written by set_holdout().
00136 REAL p_test;
// Overrides the pure virtual CrossVal::cv_round(); defined outside this header.
00137 virtual std::vector<REAL> cv_round () const;
00138
// Serialization hooks; defined outside this header.
00139 virtual bool serialize (std::ostream&, ver_list&) const;
00140 virtual bool unserialize (std::istream&, ver_list&, const id_t& = NIL_ID);
00141 };
00142
00143 }
00144
// Name-clash check, evaluated inside this header's own include guard: warn if
// __CROSSVAL_H__ is already defined at this point, then define it.
// NOTE(review): presumably meant to detect an unrelated `crossval.h' that uses
// __CROSSVAL_H__ as its guard -- confirm the intent with the maintainers.
00145 #ifdef __CROSSVAL_H__
00146 #warning "This header file may conflict with another `crossval.h' file."
00147 #endif
00148 #define __CROSSVAL_H__
00149 #endif