00001
00005 #include <assert.h>
00006 #include <cmath>
00007 #include <vector>
00008 #include <map>
00009 #include <algorithm>
00010 #include "pulse.h"
00011
// Invokes the REGISTER_CREATOR macro for lemga::Pulse. The macro is defined
// outside this file; presumably it registers the class with an object
// factory so instances can be created by id when unserializing -- TODO confirm.
00012 REGISTER_CREATOR(lemga::Pulse);
00013
/**
 * Write the elements of the range [b, e) to a stream.
 *
 * Every element is followed by a single space; a newline terminates the
 * list when the range is not empty.
 *
 * @param os  Stream to write to.
 * @param b   Iterator to the first element.
 * @param e   Iterator one past the last element.
 * @param l   When true (the default), the element count (e - b) is written
 *            on a line of its own before the elements.
 * @return false as soon as writing the count or an element fails, true
 *         otherwise. The stream state after the trailing newline is not
 *         inspected.
 */
template <typename II>
bool serialize (std::ostream& os, const II& b, const II& e, bool l = true) {
    if (l) {
        os << (e - b) << '\n';
        if (os.fail()) return false;
    }

    II cur = b;
    while (cur != e) {
        os << *cur << ' ';
        if (os.fail()) return false;
        ++cur;
    }

    // Only terminate the line when at least one element was written.
    if (b != e) os << '\n';
    return true;
}
00022
/**
 * Read one value from a stream into every position of the range [b, e).
 *
 * @param is  Stream to read from.
 * @param b   Iterator to the first position to fill.
 * @param e   Iterator one past the last position to fill.
 * @return false as soon as one extraction fails (positions after the
 *         failing one are not touched), true when all were read.
 */
template <typename II>
bool unserialize (std::istream& is, const II& b, const II& e) {
    II cur = b;
    while (cur != e) {
        is >> *cur;
        if (is.fail()) return false;
        ++cur;
    }
    return true;
}
00029
00030 namespace lemga {
00031
00032 bool Pulse::serialize (std::ostream& os, ver_list& vl) const {
00033 SERIALIZE_PARENT(LearnModel, os, vl, 1);
00034 if (!(os << idx << ' ' << th.size() << ' ' << (dir? '-':'+') << '\n'))
00035 return false;
00036 return ::serialize(os, th.begin(), th.end(), false);
00037 }
00038
00039 bool Pulse::unserialize (std::istream& is, ver_list& vl, const id_t& d) {
00040 if (d != id() && d != empty_id) return false;
00041 UNSERIALIZE_PARENT(LearnModel, is, vl, 1, v);
00042 assert(v > 0);
00043
00044 UINT nl;
00045 if (!(is >> idx >> nl)) return false;
00046 char c;
00047 if (!(is >> c) || (c != '-' && c != '+')) return false;
00048 dir = (c == '-');
00049
00050 th.resize(nl);
00051 return ::unserialize(is, th.begin(), th.end());
00052 }
00053
00054 void Pulse::set_threshold (const std::vector<REAL>& t) {
00055 assert(t.size() <= max_l);
00056 #ifndef NDEBUG
00057
00058 for (UINT i = 1; i < t.size(); ++i)
00059 assert(t[i-1] <= t[i]);
00060 #endif
00061 th = t;
00062 }
00063
00064 REAL Pulse::train () {
00065 const UINT N = n_samples;
00066 assert(ptd != 0 && ptw != 0 && ptd->size() == N);
00067 assert(n_input() > 0 && n_output() == 1);
00068
00069 std::vector<REAL> yw(N);
00070 for (UINT i = 0; i < N; ++i)
00071 yw[i] = ptd->y(i)[0] * (*ptw)[i];
00072
00073 REAL minerr = 2;
00074 std::vector<UINT> thi;
00075 std::vector<REAL> xb(N);
00076
00077
00078 std::vector<REAL> x(N), ysum(N);
00079 for (UINT d = 0; d < _n_in; ++d) {
00080
00081 std::map<REAL,REAL> xy;
00082 for (UINT i = 0; i < N; ++i)
00083 xy[ptd->x(i)[d]] += yw[i];
00084 REAL sum = 0;
00085 int last_sign = 0;
00086 std::vector<REAL>::iterator px = x.begin(), py = ysum.begin();
00087 for (std::map<REAL,REAL>::const_iterator
00088 p = xy.begin(); p != xy.end(); ++p) {
00089 static REAL last_x;
00090 const int cur_sign = (p->second > 0)? 1:((p->second < 0)? 2:3);
00091 if (last_sign != cur_sign && last_sign != 0) {
00095 *px = last_x + p->first; *py = sum * 2;
00096 ++px; ++py;
00097 }
00098 last_sign = cur_sign;
00099 last_x = p->first;
00100 sum += p->second;
00101 assert(-1.01 < sum && sum < 1.01);
00102 }
00103 *py = sum * 2;
00104 const UINT n = py - ysum.begin();
00105
00106 std::vector<REAL> e0(n+1, 0);
00107 std::vector<REAL> e1(n+1, 0);
00108 std::vector<std::vector<UINT> > t0(n+1), t1(n+1);
00109
00110
00111
00112
00113
00114
00115
00116 for (UINT l = 0; l < max_l; ++l) {
00117
00118 e0.swap(e1); t0.swap(t1);
00119
00120
00121 std::vector<REAL>::iterator p0 = e0.begin(), p1 = e1.begin();
00122 std::vector<REAL>::iterator ps = ysum.begin();
00123 for (UINT i = 0; i <= n; ++p0, ++p1, ++ps, ++i) {
00124 *p0 -= *ps; *p1 += *ps;
00125 }
00126 assert(p0 == e0.end());
00127
00128 std::vector<std::vector<UINT> >::iterator
00129 pt0 = t0.begin(), pt1 = t1.begin();
00130 REAL bst0 = 3, bst1 = 3;
00131 p0 = e0.begin(); p1 = e1.begin();
00132 for (UINT i = 0; i <= n; ++p0, ++p1, ++pt0, ++pt1, ++i) {
00133 static std::vector<UINT> tb0, tb1;
00134 assert(-2.01 < *p0 && *p0 < 2.01);
00135 assert(-2.01 < *p1 && *p1 < 2.01);
00136
00137 if (*p0 < bst0) {
00138 bst0 = *p0; tb0.swap(*pt0);
00139 if (i < n) tb0.push_back(i);
00140 }
00141 *p0 = bst0; *pt0 = tb0;
00142
00143 if (*p1 < bst1) {
00144 bst1 = *p1; tb1.swap(*pt1);
00145 if (i < n) tb1.push_back(i);
00146 }
00147 *p1 = bst1; *pt1 = tb1;
00148 }
00149 assert(p0 == e0.end());
00150 }
00151
00152 e0[n] += ysum[n] / 2;
00153 e1[n] -= ysum[n] / 2;
00154 if (e0[n] <= e1[n] && e0[n] < minerr) {
00155 minerr = e0[n]; idx = d; dir = !(max_l & 1);
00156 thi.swap(t0[n]); xb.swap(x);
00157 } else if (e1[n] < minerr) {
00158 minerr = e1[n]; idx = d; dir = (max_l & 1);
00159 thi.swap(t1[n]); xb.swap(x);
00160 }
00161 }
00162
00163 th.clear();
00164 for (UINT i = 0; i < thi.size(); ) {
00165 UINT ind = thi[i]; ++i;
00166 if (i < thi.size() && ind == thi[i])
00167 ++i;
00168 else
00169 th.push_back(xb[ind] / 2);
00170 }
00171
00172 return (1 + minerr) / 2;
00173 }
00174
00175 Output Pulse::operator() (const Input& x) const {
00176 assert(idx < n_input() && x.size() == n_input());
00177
00178 if (th.empty())
00179 return Output(1, dir? -1 : 1);
00180
00181 const UINT i =
00182 std::lower_bound(th.begin(), th.end(), x[idx]) - th.begin();
00183 return Output(1, ((i & 1) ^ dir)? -1 : 1);
00184 }
00185
00186 }