00001
00005 #include <assert.h>
00006 #include <algorithm>
00007 #include <cmath>
00008 #include "learnmodel.h"
00009 #include "datafeeder.h"
00010
00011 namespace lemga {
00012
00013 DataFeeder::DataFeeder (const pDataSet& pd)
00014 : dat(pd), perms(0), _do_normalize(MIN_MAX) {
00015 fsize = dat->size();
00016 }
00017
00018 DataFeeder::DataFeeder (std::istream& ds)
00019 : perms(0), _do_normalize(MIN_MAX) {
00020
00021 dat = load_data(ds, (1L<<30)-1);
00022 fsize = dat->size();
00023 }
00024
00025 void DataFeeder::set_train_size (UINT trn) {
00026 assert(trn < fsize);
00027 train_size = trn;
00028 }
00029
00030 bool DataFeeder::next_train_test (pDataSet& ptr, pDataSet& pte) const {
00031 DataSet *p_tr = new DataSet();
00032 DataSet *p_te = new DataSet();
00033
00034 std::vector<UINT> perm;
00035 if (!next_permutation(perm)) return false;
00036
00037 for (UINT i = 0; i < train_size; ++i)
00038 p_tr->append(dat->x(perm[i]), dat->y(perm[i]));
00039 for (UINT i = train_size; i < fsize; ++i)
00040 p_te->append(dat->x(perm[i]), dat->y(perm[i]));
00041
00042 if (_do_normalize != NONE) {
00043 LINEAR_SCALE_PARAMS lsp;
00044 switch (_do_normalize) {
00045 case MIN_MAX: lsp = min_max(*p_tr); break;
00046 case MEAN_VAR: lsp = mean_var(*p_tr); break;
00047 default: assert(false);
00048 }
00049 linear_scale(*p_tr, lsp);
00050 linear_scale(*p_te, lsp);
00051 }
00052
00053 ptr = p_tr; pte = p_te;
00054 return true;
00055 }
00056
00057 bool DataFeeder::next_permutation (std::vector<UINT>& perm) const {
00058 perm.resize(fsize);
00059
00060 if (perms == 0) {
00061 for (UINT i = 0; i < fsize; ++i)
00062 perm[i] = i;
00063 std::random_shuffle(perm.begin(), perm.end());
00064 return true;
00065 }
00066
00067 std::vector<bool> visited(fsize, false);
00068 for (UINT i = 0; i < fsize; ++i) {
00069 UINT idx;
00070 if (!((*perms) >> idx)) {
00071 if (i) std::cerr << "DataFeeder: "
00072 "Permutation stream ends prematurely\n";
00073 return false;
00074 }
00075 if (idx >= fsize || visited[idx]) {
00076 std::cerr << "DataFeeder: "
00077 "Permutation stream has errors\n";
00078 return false;
00079 }
00080 visited[idx] = true;
00081 perm[i] = idx;
00082 }
00083 return true;
00084 }
00085
00086 DataFeeder::LINEAR_SCALE_PARAMS DataFeeder::min_max (DataSet& d) {
00087 assert(d.size() > 0);
00088
00089 const Input& x0 = d.x(0);
00090 const UINT ls = x0.size();
00091 std::vector<REAL> dmin(x0), dmax(x0);
00092 for (UINT i = 1; i < d.size(); ++i) {
00093 const Input& x = d.x(i);
00094 for (UINT j = 0; j < ls; ++j) {
00095 if (dmin[j] > x[j])
00096 dmin[j] = x[j];
00097 else if (dmax[j] < x[j])
00098 dmax[j] = x[j];
00099 }
00100 }
00101
00102 LINEAR_SCALE_PARAMS l(ls);
00103 for (UINT j = 0; j < ls; ++j) {
00104 l[j].center = (dmin[j] + dmax[j]) / 2;
00105 if (dmin[j] != dmax[j])
00106 l[j].scale = 2 / (dmax[j] - dmin[j]);
00107 else
00108 l[j].scale = 0;
00109 }
00110 return l;
00111 }
00112
00113 DataFeeder::LINEAR_SCALE_PARAMS DataFeeder::mean_var (DataSet& d) {
00114 const UINT n = d.size();
00115 assert(n > 0);
00116 const UINT ls = d.x(0).size();
00117
00118 std::vector<REAL> sum1(ls, 0), sum2(ls, 0);
00119 for (UINT i = 0; i < n; ++i) {
00120 const Input& x = d.x(i);
00121 for (UINT j = 0; j < ls; ++j) {
00122 sum1[j] += x[j];
00123 sum2[j] += x[j] * x[j];
00124 }
00125 }
00126
00127 LINEAR_SCALE_PARAMS l(ls);
00128 for (UINT j = 0; j < ls; ++j) {
00129 l[j].center = sum1[j] / n;
00130 REAL n_1_var = sum2[j] - sum1[j] * l[j].center;
00131 if (n_1_var > INFINITESIMAL)
00132 l[j].scale = std::sqrt((n-1) / n_1_var);
00133 else
00134 l[j].scale = 0;
00135 }
00136 return l;
00137 }
00138
00139 void DataFeeder::linear_scale (DataSet& d, const LINEAR_SCALE_PARAMS& l) {
00140 const UINT ls = l.size();
00141 for (UINT i = 0; i < d.size(); ++i) {
00142 Input x = d.x(i);
00143 assert(x.size() == ls);
00144 for (UINT j = 0; j < ls; ++j)
00145 x[j] = (x[j] - l[j].center) * l[j].scale;
00146 d.replace(i, x, d.y(i));
00147 }
00148 }
00149
00150 }