pulse.cpp

Go to the documentation of this file.
00001 
00005 #include <assert.h>
00006 #include <cmath>
00007 #include <vector>
00008 #include <map>
00009 #include <algorithm>
00010 #include "pulse.h"
00011 
00012 REGISTER_CREATOR(lemga::Pulse);  // NOTE(review): presumably registers Pulse with LEMGA's object creator/factory so it can be created by name (e.g. when unserializing) -- confirm against the macro definition
00013 
// Write the range [b, e) to os as whitespace-separated text.
//
// If l is true, the element count (e - b) is written first on its own line.
// Each element is followed by a single space, and a non-empty range is
// terminated by a newline.  II must support `e - b` (the expression is
// instantiated even when l is false).
//
// Returns true only if every write succeeded, including the terminating
// newline; returns false as soon as one write fails.
template <typename II>
bool serialize (std::ostream& os, const II& b, const II& e, bool l = true) {
    if (l && !(os << (e - b) << '\n')) return false;
    for (II i = b; i != e; ++i)
        if (!(os << *i << ' ')) return false;
    // the last write must be checked too; otherwise a failure on the
    // terminating newline would be reported as success
    if (b != e && !(os << '\n')) return false;
    return true;
}
00022 
// Read whitespace-separated values from is into the range [b, e).
//
// Values are stored in order.  Reading stops at the first failed
// extraction, leaving the remaining elements untouched.
// Returns true iff every element of the range was read successfully.
template <typename II>
bool unserialize (std::istream& is, const II& b, const II& e) {
    II it = b;
    while (it != e && (is >> *it))
        ++it;
    // reaching e means nothing failed along the way
    return it == e;
}
00029 
00030 namespace lemga {
00031 
00032 bool Pulse::serialize (std::ostream& os, ver_list& vl) const {
00033     SERIALIZE_PARENT(LearnModel, os, vl, 1);
00034     if (!(os << idx << ' ' << th.size() << ' ' << (dir? '-':'+') << '\n'))
00035         return false;
00036     return ::serialize(os, th.begin(), th.end(), false);
00037 }
00038 
00039 bool Pulse::unserialize (std::istream& is, ver_list& vl, const id_t& d) {
00040     if (d != id() && d != empty_id) return false;
00041     UNSERIALIZE_PARENT(LearnModel, is, vl, 1, v);
00042     assert(v > 0);
00043 
00044     UINT nl;
00045     if (!(is >> idx >> nl)) return false;
00046     char c;
00047     if (!(is >> c) || (c != '-' && c != '+')) return false;
00048     dir = (c == '-');
00049 
00050     th.resize(nl);
00051     return ::unserialize(is, th.begin(), th.end());
00052 }
00053 
00054 void Pulse::set_threshold (const std::vector<REAL>& t) {
00055     assert(t.size() <= max_l);
00056 #ifndef NDEBUG
00057     // assert t is sorted (std::is_sorted is an SGL extension)
00058     for (UINT i = 1; i < t.size(); ++i)
00059         assert(t[i-1] <= t[i]);
00060 #endif
00061     th = t;
00062 }
00063 
00064 REAL Pulse::train () {
00065     const UINT N = n_samples;
00066     assert(ptd != 0 && ptw != 0 && ptd->size() == N);
00067     assert(n_input() > 0 && n_output() == 1);
00068 
00069     std::vector<REAL> yw(N);
00070     for (UINT i = 0; i < N; ++i)
00071         yw[i] = ptd->y(i)[0] * (*ptw)[i];
00072 
00073     REAL minerr = 2;         // a number large enough (> 1)
00074     std::vector<UINT> thi;   // threshold index
00075     std::vector<REAL> xb(N); // backup for sorted x
00076 
00077     // no reallocation within loops
00078     std::vector<REAL> x(N), ysum(N);
00079     for (UINT d = 0; d < _n_in; ++d) {
00080         // extract the dimension d info, collapse data with same x
00081         std::map<REAL,REAL> xy;
00082         for (UINT i = 0; i < N; ++i)
00083             xy[ptd->x(i)[d]] += yw[i];
00084         REAL sum = 0;
00085         int last_sign = 0;  // 1: pos, 2: neg, 3: zero
00086         std::vector<REAL>::iterator px = x.begin(), py = ysum.begin();
00087         for (std::map<REAL,REAL>::const_iterator
00088                  p = xy.begin(); p != xy.end(); ++p) {
00089             static REAL last_x;
00090             const int cur_sign = (p->second > 0)? 1:((p->second < 0)? 2:3);
00091             if (last_sign != cur_sign && last_sign != 0) {
00095                 *px = last_x + p->first; *py = sum * 2;
00096                 ++px; ++py;
00097             }
00098             last_sign = cur_sign;
00099             last_x = p->first;
00100             sum += p->second;
00101             assert(-1.01 < sum && sum < 1.01);
00102         }
00103         *py = sum * 2;
00104         const UINT n = py - ysum.begin();
00105 
00106         std::vector<REAL> e0(n+1, 0); // error of pulses ending with -1
00107         std::vector<REAL> e1(n+1, 0); // error of pulses ending with +1
00108         std::vector<std::vector<UINT> > t0(n+1), t1(n+1); // transitions idx
00109 
00110         // dynamic programming: compute err for level 1--max_l
00111         // e0 and e1 at the begining of loop l are
00112         //    e0[i] = 2*best_e_{i,l} - 1 - sum(w*y),
00113         //    e1[i] = 2*best_e_{i,l} - 1 + sum(w*y).
00114         // where best_e_{i,l} is the lowest error if l transitions
00115         // happens before or at position i.
00116         for (UINT l = 0; l < max_l; ++l) {
00117             // swap e0 & e1, t0 & t1 to get rid of the sign change
00118             e0.swap(e1); t0.swap(t1);
00119 
00120             // compute errors for level (l+1)
00121             std::vector<REAL>::iterator p0 = e0.begin(), p1 = e1.begin();
00122             std::vector<REAL>::iterator ps = ysum.begin();
00123             for (UINT i = 0; i <= n; ++p0, ++p1, ++ps, ++i) {
00124                 *p0 -= *ps; *p1 += *ps;
00125             }
00126             assert(p0 == e0.end());
00127 
00128             std::vector<std::vector<UINT> >::iterator
00129                 pt0 = t0.begin(), pt1 = t1.begin();
00130             REAL bst0 = 3, bst1 = 3;  // a number large enough (> 2)
00131             p0 = e0.begin(); p1 = e1.begin();
00132             for (UINT i = 0; i <= n; ++p0, ++p1, ++pt0, ++pt1, ++i) {
00133                 static std::vector<UINT> tb0, tb1;  // always the best
00134                 assert(-2.01 < *p0 && *p0 < 2.01);
00135                 assert(-2.01 < *p1 && *p1 < 2.01);
00136 
00137                 if (*p0 < bst0) {
00138                     bst0 = *p0; tb0.swap(*pt0); // => tb0 = *pt0;
00139                     if (i < n) tb0.push_back(i);
00140                 }
00141                 *p0 = bst0; *pt0 = tb0;
00142 
00143                 if (*p1 < bst1) {
00144                     bst1 = *p1; tb1.swap(*pt1); // => tb1 = *pt1;
00145                     if (i < n) tb1.push_back(i);
00146                 }
00147                 *p1 = bst1; *pt1 = tb1;
00148             }
00149             assert(p0 == e0.end());
00150         }
00151 
00152         e0[n] += ysum[n] / 2;
00153         e1[n] -= ysum[n] / 2;
00154         if (e0[n] <= e1[n] && e0[n] < minerr) {
00155             minerr = e0[n]; idx = d; dir = !(max_l & 1);
00156             thi.swap(t0[n]); xb.swap(x);
00157         } else if (e1[n] < minerr) {
00158             minerr = e1[n]; idx = d; dir = (max_l & 1);
00159             thi.swap(t1[n]); xb.swap(x);
00160         }
00161     }
00162 
00163     th.clear();
00164     for (UINT i = 0; i < thi.size(); /* empty */) {
00165         UINT ind = thi[i]; ++i;
00166         if (i < thi.size() && ind == thi[i])
00167             ++i;
00168         else
00169             th.push_back(xb[ind] / 2);
00170     }
00171 
00172     return (1 + minerr) / 2;
00173 }
00174 
00175 Output Pulse::operator() (const Input& x) const {
00176     assert(idx < n_input() && x.size() == n_input());
00177 
00178     if (th.empty())
00179         return Output(1, dir? -1 : 1);
00180 
00181     const UINT i =
00182         std::lower_bound(th.begin(), th.end(), x[idx]) - th.begin();
00183     return Output(1, ((i & 1) ^ dir)? -1 : 1);
00184 }
00185 
00186 } // namespace lemga

Generated on Mon Jan 9 23:43:24 2006 for LEMGA by doxygen 1.4.6