aggregating.cpp

Go to the documentation of this file.
00001 
00005 #include <assert.h>
00006 #include "aggregating.h"
00007 
00008 namespace lemga {
00009 
00015 void Aggregating::reset () {
00016     LearnModel::reset();
00017     lm.clear(); n_in_agg = 0;
00018     assert(lm_base == 0 || valid_dimensions(*lm_base));
00019 }
00020 
00025 Aggregating::Aggregating (const Aggregating& a)
00026     : LearnModel(a), lm_base(a.lm_base),
00027       n_in_agg(a.n_in_agg), max_n_model(a.max_n_model)
00028 {
00029     const UINT lms = a.lm.size();
00030     assert(n_in_agg <= lms);
00031     for (UINT i = 0; i < lms; ++i)
00032         lm.push_back(a.lm[i]->clone());
00033 }
00034 
00036 const Aggregating& Aggregating::operator= (const Aggregating& a) {
00037     if (&a == this) return *this;
00038 
00039     LearnModel::operator=(a);
00040     lm_base = a.lm_base;
00041     n_in_agg = a.n_in_agg;
00042     max_n_model = a.max_n_model;
00043 
00044     const UINT lms = a.lm.size();
00045     assert(n_in_agg <= lms);
00046     lm.clear();
00047     for (UINT i = 0; i < lms; ++i)
00048         lm.push_back(a.lm[i]->clone());
00049 
00050     return *this;
00051 }
00052 
00053 bool Aggregating::serialize (std::ostream& os, ver_list& vl) const {
00054     SERIALIZE_PARENT(LearnModel, os, vl, 1);
00055 
00056     if (!(os << lm.size() << ' ' << (lm_base != 0) << '\n'))
00057         return false;
00058     if (lm_base != 0)
00059         if (!(os << *lm_base)) return false;
00060     for (UINT i = 0; i < lm.size(); ++i)
00061         if (!(os << *lm[i])) return false;
00062 
00063     return true;
00064 }
00065 
00066 bool Aggregating::unserialize (std::istream& is, ver_list& vl, const id_t& d) {
00067     if (d != id() && d != NIL_ID) return false;
00068     UNSERIALIZE_PARENT(LearnModel, is, vl, 1, v);
00069 
00070     if (v == 0) /* Take care of _n_in and _n_out */
00071         if (!(is >> _n_in >> _n_out)) return false;
00072 
00073     UINT t3, t4;
00074     if (!(is >> t3 >> t4) || t4 > 1) return false;
00075 
00076     if (!t4) lm_base = 0;
00077     else {
00078         if (v == 0) { /* ignore a one-line comment */
00079             char c; is >> c;
00080             assert(c == '#');
00081             is.ignore(100, '\n');
00082         }
00083         LearnModel* p = (LearnModel*) Object::create(is);
00084         lm_base = p;
00085         if (p == 0 || !valid_dimensions(*p)) return false;
00086     }
00087 
00088     lm.clear();
00089     for (UINT i = 0; i < t3; ++i) {
00090         LearnModel* p = (LearnModel*) Object::create(is);
00091         lm.push_back(p);
00092         if (p == 0 || !exact_dimensions(*p)) return false;
00093     }
00094     n_in_agg = t3;
00095 
00096     return true;
00097 }
00098 
00102 void Aggregating::set_base_model (const LearnModel& blm) {
00103     assert(valid_dimensions(blm));
00104     lm_base = blm.clone();
00105 }
00106 
00114 bool Aggregating::set_aggregation_size (UINT n) {
00115     if (n <= size()) {
00116         n_in_agg = n;
00117         return true;
00118     }
00119     return false;
00120 }
00121 
00122 void Aggregating::set_train_data (const pDataSet& pd, const pDataWgt& pw) {
00123     LearnModel::set_train_data(pd, pw);
00124     // Note: leave the compatibility check of the base learner to training.
00125     for (UINT i = 0; i < lm.size(); ++i)
00126         if (lm[i] != 0)
00127             lm[i]->set_train_data(ptd, ptw);
00128 }
00129 
00130 }

Generated on Wed Nov 8 08:15:20 2006 for LEMGA by doxygen 1.4.6