00001
00005 #include <algorithm>
00006 #include "vectorop.h"
00007 #include "random.h"
00008 #include "crossval.h"
00009
// Creator registration for the two concrete cross-validation classes.
// NOTE(review): presumably REGISTER_CREATOR2 ties each class to the short
// name given as its second argument so that Object::create() can rebuild the
// object from a stream -- confirm against the macro definition.
REGISTER_CREATOR2(lemga::vFoldCrossVal, vfold);
REGISTER_CREATOR2(lemga::HoldoutCrossVal, holdout);
00012
00013 namespace lemga {
00014
00015 CrossVal::CrossVal (const CrossVal& cv)
00016 : LearnModel(cv), fullset(cv.fullset), lm(cv.lm), err(cv.err),
00017 n_rounds(cv.n_rounds), best(cv.best)
00018 {
00019 best_lm = 0;
00020 if (cv.best_lm != 0) {
00021 best_lm = cv.best_lm->clone();
00022 assert(best >= 0 && best_lm->id() == lm[best]->id());
00023 }
00024 }
00025
00026 const CrossVal& CrossVal::operator= (const CrossVal& cv) {
00027 if (&cv == this) return *this;
00028
00029 LearnModel::operator=(cv);
00030 fullset = cv.fullset;
00031 lm = cv.lm;
00032 err = cv.err;
00033 n_rounds = cv.n_rounds;
00034 best = cv.best;
00035
00036 best_lm = 0;
00037 if (cv.best_lm != 0) {
00038 best_lm = cv.best_lm->clone();
00039 assert(best >= 0 && best_lm->id() == lm[best]->id());
00040 }
00041
00042 return *this;
00043 }
00044
00045 bool CrossVal::serialize (std::ostream& os, ver_list& vl) const {
00046 SERIALIZE_PARENT(LearnModel, os, vl, 1);
00047
00048 UINT n = size();
00049 if (!(os << n << ' ' << best << ' ' << (best_lm != 0) << '\n'))
00050 return false;
00051 for (UINT i = 0; i < n; ++i)
00052 if (!(os << *lm[i])) return false;
00053 for (UINT i = 0; i < n; ++i)
00054 if (!(os << err[i] << ' ')) return false;
00055 os << '\n';
00056 if (best_lm != 0) {
00057 assert(best >= 0 && best_lm->id() == lm[best]->id());
00058 if (!(os << *best_lm)) return false;
00059 }
00060 return (os << n_rounds << ' ' << fullset << '\n');
00061 }
00062
00063 bool CrossVal::unserialize (std::istream& is, ver_list& vl, const id_t& d) {
00064 assert(d == NIL_ID);
00065 UNSERIALIZE_PARENT(LearnModel, is, vl, 1, v);
00066 assert(v > 0);
00067
00068 UINT n;
00069 bool trained;
00070 if (!(is >> n >> best >> trained)) return false;
00071 if (best < -1 || best >= (int) n) return false;
00072
00073 lm.clear(); err.resize(n);
00074 for (UINT i = 0; i < n; ++i) {
00075 LearnModel* p = (LearnModel*) Object::create(is);
00076 lm.push_back(p);
00077 if (p == 0 || !valid_dimensions(*p)) return false;
00078 }
00079 for (UINT i = 0; i < n; ++i)
00080 if (!(is >> err[i])) return false;
00081
00082 best_lm = 0;
00083 if (trained) {
00084 LearnModel* p = (LearnModel*) Object::create(is);
00085 best_lm = p;
00086 if (p == 0 || !exact_dimensions(*p)) return false;
00087 if (best < 0 || p->id() != lm[best]->id()) return false;
00088 }
00089
00090 return (is >> n_rounds >> fullset) && (n_rounds > 0);
00091 }
00092
00093 void CrossVal::add_model (const LearnModel& l) {
00094 set_dimensions(l);
00095 lm.push_back(l.clone());
00096 err.push_back(-1);
00097 }
00098
00099 void CrossVal::set_train_data (const pDataSet& pd, const pDataWgt& pw) {
00100 assert(pw == 0);
00101 LearnModel::set_train_data(pd, 0);
00102 if (best_lm != 0) {
00103 assert(best >= 0 && best_lm->id() == lm[best]->id());
00104 best_lm->set_train_data(pd, 0);
00105 }
00106 }
00107
00108 void CrossVal::train () {
00109 assert(n_rounds > 0 && ptd != 0 && ptw == 0);
00110 best_lm = 0;
00111
00112 std::fill(err.begin(), err.end(), 0);
00113 using namespace op;
00114 for (UINT r = 0; r < n_rounds; ++r)
00115 err += cv_round();
00116 err *= 1 / (REAL) n_rounds;
00117
00118 best = std::min_element(err.begin(), err.end()) - err.begin();
00119 if (fullset) {
00120 best_lm = lm[best]->clone();
00121 best_lm->initialize();
00122 best_lm->set_train_data(ptd);
00123 best_lm->train();
00124 set_dimensions(*best_lm);
00125 }
00126 }
00127
00128 void CrossVal::reset () {
00129 LearnModel::reset();
00130 std::fill(err.begin(), err.end(), -1);
00131 best_lm = 0; best = -1;
00132 }
00133
00134 bool vFoldCrossVal::serialize (std::ostream& os, ver_list& vl) const {
00135 SERIALIZE_PARENT(CrossVal, os, vl, 1);
00136 return (os << n_folds << '\n');
00137 }
00138
00139 bool
00140 vFoldCrossVal::unserialize (std::istream& is, ver_list& vl, const id_t& d) {
00141 if (d != id() && d != NIL_ID) return false;
00142 UNSERIALIZE_PARENT(CrossVal, is, vl, 1, v);
00143 assert(v > 0);
00144 return (is >> n_folds) && (n_folds > 1);
00145 }
00146
00147 std::vector<REAL> vFoldCrossVal::cv_round () const {
00148 assert(ptd != 0);
00149 UINT n = size(), ds = ptd->size();
00150 std::vector<REAL> cve(n, 0);
00151
00152
00153 std::vector<UINT> perm(ds);
00154 for (UINT i = 0; i < ds; ++i) perm[i] = i;
00155 std::random_shuffle(perm.begin(), perm.end());
00156
00157 UINT b, e = 0;
00158 for (UINT f = 1; f <= n_folds; ++f) {
00159
00160 b = e; e = f * ds / n_folds;
00161 assert(e-b == ds/n_folds || e-b == (ds+n_folds-1)/n_folds);
00162
00163
00164 DataSet *p_tr = new DataSet();
00165 DataSet *p_te = new DataSet();
00167 for (UINT i = 0; i < b; ++i)
00168 p_tr->append(ptd->x(perm[i]), ptd->y(perm[i]));
00169 for (UINT i = b; i < e; ++i)
00170 p_te->append(ptd->x(perm[i]), ptd->y(perm[i]));
00171 for (UINT i = e; i < ds; ++i)
00172 p_tr->append(ptd->x(perm[i]), ptd->y(perm[i]));
00173 pDataSet ptr = p_tr, pte = p_te;
00174
00175
00176 for (UINT i = 0; i < n; ++i) {
00177 pLearnModel p = lm[i]->clone();
00178 p->set_train_data(ptr);
00179 p->train();
00180
00181 cve[i] += p->test_c_error(pte) * pte->size();
00182 }
00183 }
00184 using namespace op;
00185 cve *= 1 / (REAL) ds;
00186
00187 return cve;
00188 }
00189
00190 bool HoldoutCrossVal::serialize (std::ostream& os, ver_list& vl) const {
00191 SERIALIZE_PARENT(CrossVal, os, vl, 1);
00192 return (os << p_test << '\n');
00193 }
00194
00195 bool
00196 HoldoutCrossVal::unserialize (std::istream& is, ver_list& vl, const id_t& d) {
00197 if (d != id() && d != NIL_ID) return false;
00198 UNSERIALIZE_PARENT(CrossVal, is, vl, 1, v);
00199 assert(v > 0);
00200 return (is >> p_test) && (p_test > 0 && p_test < 0.9);
00201 }
00202
00203 std::vector<REAL> HoldoutCrossVal::cv_round () const {
00204 assert(ptd != 0);
00205 const UINT n = ptd->size();
00206 UINT k = UINT(n * p_test + 0.5); if (k < 1) k = 1;
00207 DataSet *p_tr = new DataSet();
00208 DataSet *p_te = new DataSet();
00209
00210
00211
00212
00213
00214 for (UINT i = 0; i < n; ++i) {
00215 UINT toss = UINT(randu() * (n-i));
00216 assert(0 <= toss && toss < n-i);
00217 if (toss < k) {
00218 p_te->append(ptd->x(i), ptd->y(i));
00219 --k;
00220 } else
00221 p_tr->append(ptd->x(i), ptd->y(i));
00222 }
00223 assert(k == 0);
00224
00225 pDataSet ptr = p_tr, pte = p_te;
00226 const UINT lms = size();
00227 std::vector<REAL> cve(lms);
00228
00229 for (UINT i = 0; i < lms; ++i) {
00230 pLearnModel p = lm[i]->clone();
00231 p->set_train_data(ptr);
00232 p->train();
00233
00234 cve[i] = p->test_c_error(pte);
00235 }
00236
00237 return cve;
00238 }
00239
00240 }