stump.cpp

Go to the documentation of this file.
00001 
00005 #include <assert.h>
00006 #include <cmath>
00007 #include <algorithm>
00008 #include <vector>
00009 #include <map>
00010 #include "stump.h"
00011 
// Register lemga::Stump through the project's REGISTER_CREATOR macro
// (defined outside this file). Presumably this lets the framework create a
// Stump from its class id, e.g. during unserialization — confirm against the
// macro definition.
00012 REGISTER_CREATOR(lemga::Stump);
00013 
00014 namespace lemga {
00015 
00016 bool Stump::serialize (std::ostream& os, ver_list& vl) const {
00017     SERIALIZE_PARENT(LearnModel, os, vl, 2);
00018     return (os << idx << ' ' << bd1 << ' ' << bd2 << ' '
00019                << (dir? 'P':'N') << '\n');
00020 }
00021 
00022 bool Stump::unserialize (std::istream& is, ver_list& vl, const id_t& d) {
00023     if (d != id() && d != empty_id) return false;
00024     UNSERIALIZE_PARENT(LearnModel, is, vl, 2, v);
00025 
00026     if (v == 0) { /* Take care of _n_in and _n_out */
00027         if (!(is >> _n_in)) return false;
00028         assert(_n_out == 1);
00029     }
00030     /* common part for ver 0, 1, and 2 */
00031     if (!(is >> idx >> bd1)) return false;
00032     bd2 = bd1;
00033     if (v >= 2)
00034         if (!(is >> bd2) || bd1 > bd2) return false;
00035     char c;
00036     if (!(is >> c) || (c != 'P' && c != 'N')) return false;
00037     dir = (c == 'P');
00038     return true;
00039 }
00040 
// Iterator over the sorted map (input value -> summed label*weight) that
// both train_1d overloads in this file build.
00041 typedef std::map<REAL,REAL>::iterator MI;
00042 
00043 REAL Stump::train_1d (const std::vector<REAL>& x, const std::vector<REAL>& yw)
00044 {
00045     const UINT N = x.size();
00046     assert(yw.size() == N);
00047 
00048     // combine examples with same input, and sort them
00049     std::map<REAL,REAL> xy;
00050     for (UINT i = 0; i < N; ++i)
00051         xy[x[i]] += yw[i];
00052     UINT n = xy.size();
00053     xy[xy.begin()->first-2] = 1;    // insert a "very small" x
00054     assert(xy.size() == n+1);
00055     for (MI p, pn = xy.begin(); p = pn++, p != xy.end();)
00056         if (p->second > -INFINITESIMAL && p->second < INFINITESIMAL)
00057             xy.erase(p);
00058 
00059     n = xy.size();
00060     xy[xy.rbegin()->first+2] = 1;   // a "very large" x
00061     assert(n > 0 && xy.size() == n+1);
00062     const MI xyb = xy.begin();
00063 
00064     REAL minthr = 0, mine = INFINITY, e = - xyb->second;
00065     REAL cur_x = xyb->first;
00066     for (MI p = xyb; n; --n) {
00067         e += p->second;
00068         assert(p != xyb || e == 0);
00069         REAL nxt_x = (++p)->first;
00070         REAL cur_thr = (cur_x + nxt_x) / 2;
00071         cur_x = nxt_x;
00072 
00073         if (e < mine) {
00074             mine = e; minthr = cur_thr;
00075         }
00076     }
00077 
00078     return minthr;
00079 }
00080 
00081 /* yw_inf is yw[infinity] */
00082 REAL Stump::train_1d (const std::vector<REAL>& x, const std::vector<REAL>& yw,
00083                       REAL yw_inf, bool& ir, bool& mind, REAL& th1, REAL& th2)
00084 {
00085     const UINT N = x.size();
00086     assert(yw.size() == N);
00087 
00088     // combine examples with same input, and sort them
00089     std::map<REAL,REAL> xy;
00090     for (UINT i = 0; i < N; ++i)
00091         xy[x[i]] += yw[i];
00092 #ifndef NDEBUG
00093     UINT n = xy.size();
00094 #endif
00095     xy[xy.begin()->first-2] = 1;    // insert a "very small" x
00096     assert(xy.size() == n+1);
00097     for (MI p, pn = xy.begin(); p = pn++, p != xy.end();)
00098         if (p->second > -INFINITESIMAL && p->second < INFINITESIMAL)
00099             xy.erase(p);
00100 
00101 #ifndef NDEBUG
00102     // check whether a constant function is enough
00103     bool all_the_same = true;
00104     MI pb = xy.begin(); ++pb;
00105     for (MI p = pb; p != xy.end(); ++p)
00106         if (p->second < -INFINITESIMAL) { all_the_same = false; break; }
00107 
00108     if (!all_the_same) {
00109         all_the_same = true;
00110         for (MI p = pb; p != xy.end(); ++p)
00111             if (p->second > INFINITESIMAL) { all_the_same = false; break; }
00112     }
00113 
00114     if (all_the_same)
00115         std::cerr << "Stump: Warning: all y's are the same.\n";
00116 #endif
00117     const MI xyb = xy.begin(), xye = xy.end();
00118 
00119     REAL mine = 0, maxe = 0, e = -1;
00120     MI mint, maxt, p;
00121     for (mint = maxt = p = xyb; p != xye; ++p) {
00122         e += p->second;
00123         assert(p != xyb || e == 0);
00124         if (e < mine) {
00125             mine = e; mint = p;
00126         } else if (e > maxe) {
00127             maxe = e; maxt = p;
00128         }
00129         // we prefer middle indices
00130         if (mint == xyb && e == 0) mint = p;
00131         if (maxt == xyb && e == 0) maxt = p;
00132     }
00133     e = (1-e-yw_inf) / 2;// error of y = sgn(x > -Inf)
00134     mine += e;           // starting with y = -1
00135     maxe = 1 - (maxe+e); // starting with y = 1
00136 
00137     // unify the solution to mind, mine, and mint
00138     if (std::fabs(mine - maxe) < EPSILON) {
00139         MI nxtt = mint; ++nxtt;
00140         mind = (mint != xyb && nxtt != xye);
00141     }
00142     else mind = (mine < maxe);
00143     if (!mind) {
00144         mine = maxe; mint = maxt;
00145     }
00146 
00147     MI nxtt = mint; ++nxtt;
00148     ir = (mint != xyb && nxtt != xye);
00149     assert(!ir || !all_the_same);
00150     th1 = mint->first;
00151     if (ir)
00152         th2 = nxtt->first;
00153     else
00154         th2 = th1 + 2;
00155 
00156     return mine;
00157 }
00158 
00159 /* Find the optimal dimension and threshold */
00160 REAL Stump::train () {
00161     const UINT N = n_samples;
00162     assert(ptd != 0 && ptw != 0 && ptd->size() == N);
00163     assert(n_input() > 0 && n_output() == 1);
00164 
00165     // weight the examples
00166     std::vector<REAL> yw(N);
00167     for (UINT i = 0; i < N; ++i)
00168         yw[i] = ptd->y(i)[0] * (*ptw)[i];
00169 
00170     REAL minerr = 2;
00171     bool minir = true;
00172 
00173     std::vector<REAL> x(N);
00174     for (UINT d = 0; d < _n_in; ++d) {
00175         for (UINT i = 0; i < N; ++i)
00176             x[i] = ptd->x(i)[d];
00177 
00178         bool mind, ir;
00179         REAL th1, th2;
00180         REAL mine = train_1d(x, yw, 0, ir, mind, th1, th2);
00181 
00182         if (mine < minerr) {
00183             minerr = mine;
00184             minir = ir;
00185             dir = mind; idx = d;
00186             bd1 = th1; bd2 = th2;
00187         }
00188     }
00189 
00190     if (!minir)
00191         std::cerr << "Stump: Warning: threshold out of range.\n";
00192     return minerr;
00193 }
00194 
00195 Output Stump::operator() (const Input& x) const {
00196     assert(idx < n_input() && x.size() == n_input());
00197     assert(bd2 >= bd1);
00198     REAL y = x[idx]*2 - (bd1 + bd2);
00199     if (hard || bd1 == bd2)
00200         y = (y < 0)? -1 : 1;
00201     else {
00202         y /= bd2 - bd1;
00203         y = (y<-1)? -1 : (y>1)? 1 : y;
00204     }
00205     return Output(1, dir? y : -y);
00206 }
00207 
00208 } // namespace lemga

Generated on Mon Jan 9 23:43:24 2006 for LEMGA by  doxygen 1.4.6