ordinal_ble.cpp

Go to the documentation of this file.
00001 
00005 #include <cmath>
00006 #include "ordinal_ble.h"
00007 
// Register Ordinal_BLE with LEMGA's object-creation machinery.
// NOTE(review): inferred from the macro name; presumably this is what lets
// Object::create() (used by unserialize() for the inner learner) rebuild an
// Ordinal_BLE from a stream -- confirm against the REGISTER_CREATOR definition.
00008 REGISTER_CREATOR(lemga::Ordinal_BLE);
00009 
00010 namespace lemga {
00011 
/* Rank labels are stored in Output values as 1, 2, 3, ...; internally ranks
 * are 0-based indices.  OUT2RANK maps a label to its index (label - 1, via
 * truncation of label - 0.5) and RANK2OUT maps an index back to its label. */
#define OUT2RANK(y)  (static_cast<UINT>((y) - 0.5))
#define RANK2OUT(r)  ((r) + 1)
/* A label is a valid rank when it exceeds 0.9 and lies within INFINITESIMAL
 * of label-index+1, i.e. it is (numerically) an integer >= 1. */
#define VALIDRANK(y) (((y) > 0.9) && std::fabs((y) - 1 - OUT2RANK(y)) < INFINITESIMAL)
/* Table dimensions: number of ranks (rows of the ECOC table) and number of
 * binary hypotheses (rows of the extension table). */
#define nrank  (out_tab.size())
#define n_hyp  (ext_tab.size())
00017 
00018 bool Ordinal_BLE::serialize (std::ostream& os, ver_list& vl) const {
00019     SERIALIZE_PARENT(LearnModel, os, vl, 1);
00020 
00021     if (!(os << nrank << ' ' << n_hyp << ' ' << n_ext << ' '
00022              << full_ext << ' ' << (lm != 0) << '\n'))
00023         return false;
00024 
00025     for (UINT i = 0; i < nrank; ++i) {
00026         assert(out_tab[i].size() == n_hyp);
00027         for (UINT j = 0; j < n_hyp; ++j)
00028             os << out_tab[i][j] << ' ';
00029         if (n_hyp) os << '\n';
00030     }
00031     for (UINT i = 0; i < n_hyp; ++i) {
00032         assert(ext_tab[i].size() == n_ext);
00033         for (UINT j = 0; j < n_ext; ++j)
00034             os << ext_tab[i][j] << ' ';
00035         if (n_ext) os << '\n';
00036     }
00037 
00038     if (lm != 0) {
00039         assert(_n_in == 0 || lm->valid_dimensions(_n_in+n_ext, 1));
00040         return (os << *lm);
00041     }
00042     return os;
00043 }
00044 
00045 bool Ordinal_BLE::unserialize (std::istream& is, ver_list& vl, const id_t& d) {
00046     if (d != id() && d != NIL_ID) return false;
00047     UNSERIALIZE_PARENT(LearnModel, is, vl, 1, v);
00048     assert(v > 0);
00049 
00050     lm = 0; ptd = 0; ptw = 0; ext_d = 0; ext_w = 0;
00051     d_nrank = 0; reset_data = true;
00052 
00053     UINT nr, nh, fe, tr;
00054     if (!(is >> nr >> nh >> n_ext >> fe >> tr) || fe > 1 || tr > 1)
00055         return false;
00056     full_ext = fe;
00057 
00058     out_tab.resize(nr); ext_tab.resize(nh);
00059     for (UINT i = 0; i < nrank; ++i) {
00060         out_tab[i].resize(n_hyp);
00061         for (UINT j = 0; j < n_hyp; ++j)
00062             is >> out_tab[i][j];
00063     }
00064     for (UINT i = 0; i < n_hyp; ++i) {
00065         ext_tab[i].resize(n_ext);
00066         for (UINT j = 0; j < n_ext; ++j)
00067             is >> ext_tab[i][j];
00068     }
00069 
00070     if (tr) {
00071         lm = (LearnModel*) Object::create(is);
00072         if (lm == 0 || !(_n_in == 0 || lm->valid_dimensions(_n_in+n_ext, 1)))
00073             return false;
00074     }
00075     return true;
00076 }
00077 
00078 Ordinal_BLE::Ordinal_BLE (const Ordinal_BLE& o)
00079     : LearnModel(o), lm(0), full_ext(o.full_ext),
00080       out_tab(o.out_tab), ext_tab(o.ext_tab), n_ext(o.n_ext),
00081       ext_d(o.ext_d), ext_w(o.ext_w), d_nrank(o.d_nrank),
00082       reset_data(o.reset_data)
00083 {
00084     if (o.lm != 0) lm = o.lm->clone();
00085 }
00086 
00087 const Ordinal_BLE& Ordinal_BLE::operator= (const Ordinal_BLE& o) {
00088     if (&o == this) return *this;
00089 
00090     LearnModel::operator=(o);
00091     lm = 0; full_ext = o.full_ext;
00092     out_tab = o.out_tab; ext_tab = o.ext_tab; n_ext = o.n_ext;
00093     ext_d = o.ext_d; ext_w = o.ext_w; d_nrank = o.d_nrank;
00094     reset_data = o.reset_data;
00095     if (o.lm != 0) lm = o.lm->clone();
00096 
00097     return *this;
00098 }
00099 
00100 void Ordinal_BLE::set_model (const LearnModel& l) {
00101     lm = l.clone();
00102     reset_data = true;
00103 }
00104 
00105 void Ordinal_BLE::set_full_extension (bool f) {
00106     assert(f); //??? only deal with full-extension for now
00107     if (full_ext ^ f) { // full_ext will be changed
00108         ext_d = 0; ext_w = 0;
00109     }
00110     full_ext = f;
00111 }
00112 
00113 void Ordinal_BLE::set_tables (const ECOC_TABLE& ecc, const EXT_TABLE& ext) {
00114     out_tab = ecc; ext_tab = ext;
00115     assert(nrank > 1 && n_hyp > 0);
00116     n_ext = ext_tab[0].size();
00117 #ifndef NDEBUG
00118     for (UINT i = 0; i < nrank; ++i)
00119         assert(out_tab[i].size() == n_hyp);
00120     for (UINT i = 0; i < n_hyp; ++i)
00121         assert(ext_tab[i].size() == n_ext);
00122 #endif
00123     local_d.resize(nrank);
00124 }
00125 
00126 void Ordinal_BLE::set_tables (BLE_TYPE bt, UINT nr) {
00127     ECOC_TABLE ecc(nr);
00128     EXT_TABLE ext;
00129 
00130     switch (bt) {
00131     case MULTI_THRESHOLD:
00132         assert(nr > 1);
00133         for (UINT i = 0; i < nr; ++i) {
00134             ecc[i].resize(nr-1, -1);
00135             for (UINT j = 0; j < i; ++j)
00136                 ecc[i][j] = 1;
00137         }
00138         ext.resize(nr-1);
00139         for (UINT i = 0; i < nr-1; ++i) {
00140             ext[i].resize(nr-1, 0);
00141             ext[i][i] = 1;
00142         }
00143         break;
00144 
00145     default:
00146         assert(false);
00147     }
00148 
00149     set_tables(ecc, ext);
00150 }
00151 
00152 REAL Ordinal_BLE::c_error (const Output& out, const Output& y) const {
00153     assert(n_output() == 1 && VALIDRANK(out[0]) && VALIDRANK(y[0]));
00154     return OUT2RANK(out[0]) != OUT2RANK(y[0]);
00155 }
00156 
00157 REAL Ordinal_BLE::r_error (const Output& out, const Output& y) const {
00158     assert(n_output() == 1 && VALIDRANK(out[0]) && VALIDRANK(y[0]));
00159     return std::fabs(out[0] - y[0]);
00160 }
00161 
00162 void Ordinal_BLE::set_train_data (const pDataSet& pd, const pDataWgt& pw) {
00163     pDataSet old_ptd = ptd;
00164     LearnModel::set_train_data(pd, pw);
00165     if (old_ptd == ptd) return;
00166 
00167     ext_d = 0; ext_w = 0;
00168     UINT old_nr = d_nrank;
00169 
00170     // let's be sure that the labels are 1-K (nrank)
00171     std::vector<bool> has_example;
00172     UINT nr = 0;
00173     for (UINT i = 0; i < n_samples; ++i) {
00174         REAL y = ptd->y(i)[0];
00175         if (!VALIDRANK(y)) {
00176             std::cerr << "Ordinal_BLE: Error: "
00177                       << "Label (" << y << ") is not a valid rank.\n";
00178             std::exit(-1);
00179         }
00180         UINT r = OUT2RANK(y);
00181         if (r >= has_example.size())
00182             has_example.resize(r+1, false);
00183         nr += !has_example[r];
00184         has_example[r] = true;
00185     }
00186     d_nrank = has_example.size();
00187     if (nr < d_nrank) {
00188         std::cerr << "Ordinal_BLE: Warning: " << "Missing rank(s) ";
00189         for (UINT r = 0; r < d_nrank; ++r)
00190             if (!has_example[r]) {
00191                 std::cerr << RANK2OUT(r);
00192                 if (++nr < d_nrank) std::cerr << ", ";
00193             }
00194         std::cerr << ".\n";
00195     }
00196 
00197     if (old_nr > 0 && old_nr != d_nrank)
00198         std::cerr << "Ordinal_BLE: Warning: "
00199                   << "Number of ranks changed from " << old_nr
00200                   << " to " << d_nrank << ".\n";
00201 }
00202 
00203 void Ordinal_BLE::extend_data () {
00204     //??? full extension only
00205     assert(n_hyp > 0 && ptd != 0 && n_samples > 0);
00206 
00207     DataSet* rd = new DataSet;
00208     DataWgt* rw = new DataWgt;
00209     rw->reserve(n_samples * n_hyp);
00210 
00211     // don't assume _n_in has been set
00212     UINT nin = ptd->x(0).size();
00213     Input rx(nin + n_ext);
00214     for (UINT i = 0; i < n_samples; ++i) {
00215         const Input& x = ptd->x(i);
00216         const UINT r = OUT2RANK(ptd->y(i)[0]);
00217         assert(x.size() == nin && r < nrank);
00218         const REAL wgt = (*ptw)[i] / n_hyp;
00219 
00220         for (UINT j = 0; j < n_hyp; ++j) {
00221             REAL ry;
00222             extend_example(x, r, j, rx, ry);
00223             rd->append(rx, Output(1, ry));
00224             rw->push_back(wgt);
00225         }
00226     }
00227 
00228     ext_d = rd; ext_w = rw;
00229     reset_data = true;
00230 }
00231 
00232 void Ordinal_BLE::train () {
00233     assert(ptd != 0 && ptw != 0);
00234     assert(lm != 0);
00235     if (nrank == 0) // set the default tables
00236         set_tables(BLE_DEFAULT, d_nrank);
00237 
00238     assert(nrank > 0 && n_hyp > 0);
00239     if (d_nrank > nrank) {
00240         std::cerr << "Ordinal_BLE: Error: "
00241                   << "More ranks in the data than in the ECC matrix.\n";
00242         std::exit(-1);
00243     } else if (d_nrank < nrank)
00244         std::cerr << "Ordinal_BLE: Warning: "
00245                   << "Less ranks in the data than in the ECC matrix.\n";
00246 
00247     set_dimensions(*ptd);
00248     if (ext_d == 0) extend_data();
00249     assert(ext_d != 0 && ext_w != 0);
00250 
00251     if (reset_data)
00252         lm->set_train_data(ext_d, ext_w);
00253     reset_data = false;
00254     lm->train();
00255 }
00256 
00257 void Ordinal_BLE::reset () {
00258     LearnModel::reset();
00259     if (lm != 0) lm->reset();
00260 }
00261 
/* GET_BEST_RANK(expr): declares `UINT rmin` and `REAL dmin` in the enclosing
 * scope, evaluates `expr` (which may use the loop index `r`) for every rank
 * r in [0, nrank), and leaves dmin = the smallest value and rmin = the first
 * r attaining it.  If no value is below INFINITY, rmin stays UINT(-1). */
#define GET_BEST_RANK(distance_to_rank_r)               \
    UINT rmin = UINT(-1);                               \
    REAL dmin = INFINITY;                               \
    for (UINT r = 0; r < nrank; ++r) {                  \
        const REAL dr = (distance_to_rank_r);           \
        assert(dr < INFINITY/10);                       \
        if (dr < dmin) { rmin = r; dmin = dr; }         \
    }
00269 
00270 Output Ordinal_BLE::operator() (const Input& x) const {
00271     assert(valid_dimensions(x.size(), 1));
00272     const std::vector<REAL> d = distances(x);
00273     GET_BEST_RANK(d[r]);
00274     return Output(1, RANK2OUT(rmin));
00275 }
00276 
00278 void Ordinal_BLE::extend_input (const Input& x, UINT t, Input& ext_x) const {
00279     UINT n_in = x.size();
00280     assert(t < n_hyp && ext_tab[t].size() == n_ext);
00281     assert(ext_x.size() == n_in + n_ext);
00282     //ext_x.resize(n_in + n_ext);
00283     std::copy(x.begin(), x.end(), ext_x.begin());
00284     std::copy(ext_tab[t].begin(), ext_tab[t].end(), ext_x.begin()+n_in);
00285 }
00286 
00287 void Ordinal_BLE::extend_example (const Input& x, UINT r, UINT t,
00288                                   Input& ext_x, REAL& ext_y) const
00289 {
00290     assert(r < nrank && out_tab[r].size() == n_hyp);
00291     extend_input(x, t, ext_x);
00292     ext_y = out_tab[r][t];
00293 }
00294 
00295 REAL Ordinal_BLE::ECOC_distance (const Output& o,
00296                                  const ECOC_VECTOR& cw) const {
00297     assert(o.size() == n_hyp && n_hyp <= cw.size());
00298     REAL d = 0;
00299     for (UINT i = 0; i < n_hyp; ++i)
00300         d += std::exp(- o[i] * cw[i]);
00301     return d;
00302 }
00303 
00304 const std::vector<REAL>& Ordinal_BLE::distances (const Input& x) const {
00305     UINT nin = x.size();
00306     assert(lm != 0 && lm->exact_dimensions(nin+n_ext, 1));
00307     Input rx(nin + n_ext);
00308     Output out(n_hyp);
00309     for (UINT j = 0; j < n_hyp; ++j) {
00310         extend_input(x, j, rx);
00311         out[j] = (*lm)(rx)[0];
00312     }
00313     std::vector<REAL>& d = local_d;
00314     assert(local_d.size() == nrank);
00315     for (UINT i = 0; i < nrank; ++i)
00316         d[i] = ECOC_distance(out, out_tab[i]);
00317     return d;
00318 }
00319 
00320 } // namespace lemga

Generated on Wed Nov 8 08:15:22 2006 for LEMGA by doxygen 1.4.6