00001
00005 #include <assert.h>
00006 #include "aggregating.h"
00007
00008 namespace lemga {
00009
00015 void Aggregating::clear () {
00016 lm.clear();
00017 n_in_agg = 0;
00018 }
00019
00021 Aggregating::Aggregating ()
00022 : LearnModel(0,0), lm_base(), n_in_agg(0), max_n_model(0)
00023 { }
00024
00029 Aggregating::Aggregating (const Aggregating& a)
00030 : LearnModel(a), n_in_agg(a.n_in_agg), max_n_model(a.max_n_model)
00031 {
00032 lm_base = a.lm_base->clone();
00033
00034 const UINT lms = a.lm.size();
00035 assert(n_in_agg <= lms);
00036 for (UINT i = 0; i < lms; ++i)
00037 lm.push_back(a.lm[i]->clone());
00038 }
00039
00041 const Aggregating& Aggregating::operator= (const Aggregating& a) {
00042 if (&a == this) return *this;
00043
00044 clear();
00045 LearnModel::operator=(a);
00046
00047 lm_base = a.lm_base->clone();
00048 n_in_agg = a.n_in_agg;
00049 max_n_model = a.max_n_model;
00050
00051 const UINT lms = a.lm.size();
00052 assert(n_in_agg <= lms);
00053 for (UINT i = 0; i < lms; ++i)
00054 lm.push_back(a.lm[i]->clone());
00055
00056 return *this;
00057 }
00058
00059 bool Aggregating::serialize (std::ostream& os, ver_list& vl) const {
00060 SERIALIZE_PARENT(LearnModel, os, vl, 1);
00061
00062 if (!(os << lm.size() << ' ' << (lm_base != 0) << '\n'))
00063 return false;
00064 if (lm_base != 0)
00065 if (!(os << *lm_base)) return false;
00066 for (UINT i = 0; i < lm.size(); ++i)
00067 if (!(os << *lm[i])) return false;
00068
00069 return true;
00070 }
00071
00072 bool Aggregating::unserialize (std::istream& is, ver_list& vl, const id_t& d) {
00073 if (d != id() && d != empty_id) return false;
00074 UNSERIALIZE_PARENT(LearnModel, is, vl, 1, v);
00075
00076 if (v == 0)
00077 if (!(is >> _n_in >> _n_out)) return false;
00078
00079 UINT t3, t4;
00080 if (!(is >> t3 >> t4) || t4 > 1) return false;
00081
00082 clear();
00083
00084 if (!t4) lm_base = 0;
00085 else {
00086 if (v == 0) {
00087 char c; is >> c;
00088 assert(c == '#');
00089 is.ignore(100, '\n');
00090 }
00091 LearnModel* p = (LearnModel*) Object::create(is);
00092 if (p == 0)
00093 return false;
00094 lm_base = p;
00095 }
00096
00097 for (UINT i = 0; i < t3; ++i) {
00098 LearnModel* p = (LearnModel*) Object::create(is);
00099 if (p == 0 || p->n_input() != _n_in || p->n_output() != _n_out)
00100 return false;
00101 lm.push_back(p);
00102 }
00103 n_in_agg = t3;
00104
00105 return true;
00106 }
00107
00111 void Aggregating::set_base_model (const LearnModel& blm) {
00112 lm_base = blm.clone();
00113 if (!_n_in) _n_in = lm_base->n_input();
00114 if (!_n_out) _n_out = lm_base->n_output();
00115 assert((blm.n_input() == n_input() || !blm.n_input()) &&
00116 (blm.n_output() == n_output() || !blm.n_output()));
00117 }
00118
00126 bool Aggregating::set_aggregation_size (UINT n) {
00127 if (n <= size()) {
00128 n_in_agg = n;
00129 return true;
00130 }
00131 else return false;
00132 }
00133
00134 void Aggregating::initialize () {
00135 clear();
00136 assert(lm_base != NULL);
00137 lm_base->initialize();
00138 }
00139
00140 void Aggregating::set_train_data (const pDataSet& pd, const pDataWgt& pw) {
00141 LearnModel::set_train_data(pd, pw);
00142 for (UINT i = 0; i < lm.size(); ++i)
00143 if (lm[i] != 0)
00144 lm[i]->set_train_data(ptd, ptw);
00145 }
00146
00147 }