crossval.cpp — source listing of the LEMGA cross-validation models.

(See the accompanying generated documentation for this file.)
00001 
#include <algorithm>
#include "vectorop.h"   // namespace op: element-wise += and *= on std::vector
#include "random.h"     // randu(): uniform random number in [0,1)
#include "crossval.h"

// Register the two concrete cross-validation models with the object
// factory under the names "vfold" and "holdout" — presumably so
// Object::create() can unserialize them; verify against object.h.
REGISTER_CREATOR2(lemga::vFoldCrossVal, vfold);
REGISTER_CREATOR2(lemga::HoldoutCrossVal, holdout);
00012 
00013 namespace lemga {
00014 
00015 CrossVal::CrossVal (const CrossVal& cv)
00016     : LearnModel(cv), fullset(cv.fullset), lm(cv.lm), err(cv.err),
00017       n_rounds(cv.n_rounds), best(cv.best)
00018 {
00019     best_lm = 0;
00020     if (cv.best_lm != 0) {
00021         best_lm = cv.best_lm->clone();
00022         assert(best >= 0 && best_lm->id() == lm[best]->id());
00023     }
00024 }
00025 
00026 const CrossVal& CrossVal::operator= (const CrossVal& cv) {
00027     if (&cv == this) return *this;
00028 
00029     LearnModel::operator=(cv);
00030     fullset  = cv.fullset;
00031     lm       = cv.lm;
00032     err      = cv.err;
00033     n_rounds = cv.n_rounds;
00034     best     = cv.best;
00035 
00036     best_lm = 0;
00037     if (cv.best_lm != 0) {
00038         best_lm = cv.best_lm->clone();
00039         assert(best >= 0 && best_lm->id() == lm[best]->id());
00040     }
00041 
00042     return *this;
00043 }
00044 
00045 bool CrossVal::serialize (std::ostream& os, ver_list& vl) const {
00046     SERIALIZE_PARENT(LearnModel, os, vl, 1);
00047 
00048     UINT n = size();
00049     if (!(os << n << ' ' << best << ' ' << (best_lm != 0) << '\n'))
00050         return false;
00051     for (UINT i = 0; i < n; ++i)
00052         if (!(os << *lm[i])) return false;
00053     for (UINT i = 0; i < n; ++i)
00054         if (!(os << err[i] << ' ')) return false;
00055     os << '\n';
00056     if (best_lm != 0) {
00057         assert(best >= 0 && best_lm->id() == lm[best]->id());
00058         if (!(os << *best_lm)) return false;
00059     }
00060     return (os << n_rounds << ' ' << fullset << '\n');
00061 }
00062 
/** Rebuild the model from a stream written by serialize().
 *  Reads: n best trained / n candidate models / n errors /
 *  [best model, only if trained] / n_rounds fullset.
 *  @return false on any read failure or consistency violation; the
 *  object may then be left partially overwritten. */
bool CrossVal::unserialize (std::istream& is, ver_list& vl, const id_t& d) {
    assert(d == NIL_ID);
    UNSERIALIZE_PARENT(LearnModel, is, vl, 1, v);
    assert(v > 0);

    UINT n;
    bool trained;
    if (!(is >> n >> best >> trained)) return false;
    // best == -1 means "no winner chosen yet"; otherwise it indexes lm
    if (best < -1 || best >= (int) n) return false;

    lm.clear(); err.resize(n);
    for (UINT i = 0; i < n; ++i) {
        LearnModel* p = (LearnModel*) Object::create(is);
        // push before validating so the container tracks p even on the
        // failure path (presumably lm holds refcounted pointers that
        // will release it — verify against crossval.h)
        lm.push_back(p);
        if (p == 0 || !valid_dimensions(*p)) return false;
    }
    for (UINT i = 0; i < n; ++i)
        if (!(is >> err[i])) return false;

    best_lm = 0;
    if (trained) {
        LearnModel* p = (LearnModel*) Object::create(is);
        best_lm = p;
        if (p == 0 || !exact_dimensions(*p)) return false;
        // the stored trained model must match the recorded best candidate
        if (best < 0 || p->id() != lm[best]->id()) return false;
    }

    return (is >> n_rounds >> fullset) && (n_rounds > 0);
}
00092 
00093 void CrossVal::add_model (const LearnModel& l) {
00094     set_dimensions(l);
00095     lm.push_back(l.clone());
00096     err.push_back(-1);
00097 }
00098 
00099 void CrossVal::set_train_data (const pDataSet& pd, const pDataWgt& pw) {
00100     assert(pw == 0); // cannot deal with sample weights
00101     LearnModel::set_train_data(pd, 0);
00102     if (best_lm != 0) {
00103         assert(best >= 0 && best_lm->id() == lm[best]->id());
00104         best_lm->set_train_data(pd, 0);
00105     }
00106 }
00107 
00108 void CrossVal::train () {
00109     assert(n_rounds > 0 && ptd != 0 && ptw == 0);
00110     best_lm = 0;
00111 
00112     std::fill(err.begin(), err.end(), 0);
00113     using namespace op;
00114     for (UINT r = 0; r < n_rounds; ++r)
00115         err += cv_round();
00116     err *= 1 / (REAL) n_rounds;
00117 
00118     best = std::min_element(err.begin(), err.end()) - err.begin();
00119     if (fullset) {
00120         best_lm = lm[best]->clone();
00121         best_lm->initialize();
00122         best_lm->set_train_data(ptd);
00123         best_lm->train();
00124         set_dimensions(*best_lm);
00125     }
00126 }
00127 
00128 void CrossVal::reset () {
00129     LearnModel::reset();
00130     std::fill(err.begin(), err.end(), -1);
00131     best_lm = 0; best = -1;
00132 }
00133 
00134 bool vFoldCrossVal::serialize (std::ostream& os, ver_list& vl) const {
00135     SERIALIZE_PARENT(CrossVal, os, vl, 1);
00136     return (os << n_folds << '\n');
00137 }
00138 
00139 bool
00140 vFoldCrossVal::unserialize (std::istream& is, ver_list& vl, const id_t& d) {
00141     if (d != id() && d != NIL_ID) return false;
00142     UNSERIALIZE_PARENT(CrossVal, is, vl, 1, v);
00143     assert(v > 0);
00144     return (is >> n_folds) && (n_folds > 1);
00145 }
00146 
00147 std::vector<REAL> vFoldCrossVal::cv_round () const {
00148     assert(ptd != 0);
00149     UINT n = size(), ds = ptd->size();
00150     std::vector<REAL> cve(n, 0);
00151 
00152     // get a random index
00153     std::vector<UINT> perm(ds);
00154     for (UINT i = 0; i < ds; ++i) perm[i] = i;
00155     std::random_shuffle(perm.begin(), perm.end());
00156 
00157     UINT b, e = 0;
00158     for (UINT f = 1; f <= n_folds; ++f) {
00159         // [b,e) is the index range for the testing set
00160         b = e; e = f * ds / n_folds;
00161         assert(e-b == ds/n_folds || e-b == (ds+n_folds-1)/n_folds);
00162 
00163         // generate the training and testing sets
00164         DataSet *p_tr = new DataSet();
00165         DataSet *p_te = new DataSet();
00167         for (UINT i = 0; i < b; ++i)
00168             p_tr->append(ptd->x(perm[i]), ptd->y(perm[i]));
00169         for (UINT i = b; i < e; ++i)
00170             p_te->append(ptd->x(perm[i]), ptd->y(perm[i]));
00171         for (UINT i = e; i < ds; ++i)
00172             p_tr->append(ptd->x(perm[i]), ptd->y(perm[i]));
00173         pDataSet ptr = p_tr, pte = p_te;
00174 
00175         // go over all candidates and collect the errors
00176         for (UINT i = 0; i < n; ++i) {
00177             pLearnModel p = lm[i]->clone();
00178             p->set_train_data(ptr);
00179             p->train();
00180             // which error to collect? let's assume classification error
00181             cve[i] += p->test_c_error(pte) * pte->size();
00182         }
00183     }
00184     using namespace op;
00185     cve *= 1 / (REAL) ds;
00186 
00187     return cve;
00188 }
00189 
00190 bool HoldoutCrossVal::serialize (std::ostream& os, ver_list& vl) const {
00191     SERIALIZE_PARENT(CrossVal, os, vl, 1);
00192     return (os << p_test << '\n');
00193 }
00194 
00195 bool
00196 HoldoutCrossVal::unserialize (std::istream& is, ver_list& vl, const id_t& d) {
00197     if (d != id() && d != NIL_ID) return false;
00198     UNSERIALIZE_PARENT(CrossVal, is, vl, 1, v);
00199     assert(v > 0);
00200     return (is >> p_test) && (p_test > 0 && p_test < 0.9);
00201 }
00202 
00203 std::vector<REAL> HoldoutCrossVal::cv_round () const {
00204     assert(ptd != 0);
00205     const UINT n = ptd->size();
00206     UINT k = UINT(n * p_test + 0.5); if (k < 1) k = 1;
00207     DataSet *p_tr = new DataSet();
00208     DataSet *p_te = new DataSet();
00209 
00210     // (n,k): choosing k examples from n ones.
00211     // To generate (n,k), we pick the 1st example with probability k/n,
00212     // and do (n-1,k-1) if the example is picked, or (n-1,k) if it is not.
00213     // Note: we may break out when k reaches 0 to save some randu() calls.
00214     for (UINT i = 0; i < n; ++i) {
00215         UINT toss = UINT(randu() * (n-i));
00216         assert(0 <= toss && toss < n-i);
00217         if (toss < k) {
00218             p_te->append(ptd->x(i), ptd->y(i));
00219             --k;
00220         } else
00221             p_tr->append(ptd->x(i), ptd->y(i));
00222     }
00223     assert(k == 0);
00224 
00225     pDataSet ptr = p_tr, pte = p_te;
00226     const UINT lms = size();
00227     std::vector<REAL> cve(lms);
00228     // go over all candidates and collect the errors
00229     for (UINT i = 0; i < lms; ++i) {
00230         pLearnModel p = lm[i]->clone();
00231         p->set_train_data(ptr);
00232         p->train();
00233         // which error to collect? let's assume classification error
00234         cve[i] = p->test_c_error(pte);
00235     }
00236 
00237     return cve;
00238 }
00239 
00240 } // namespace lemga

Generated on Wed Nov 8 08:15:20 2006 for LEMGA by doxygen 1.4.6