00001
00002 #ifndef __LEMGA_LEARNMODEL_H__
00003 #define __LEMGA_LEARNMODEL_H__
00004
00013 #include <assert.h>
00014 #include <vector>
00015 #include "object.h"
00016 #include "dataset.h"
00017 #include "shared_ptr.h"
00018
00019 #define VERBOSE_OUTPUT 1
00020
00021 namespace lemga {
00022
00023 typedef std::vector<REAL> Input;
00024 typedef std::vector<REAL> Output;
00025
00026 typedef dataset<Input,Output> DataSet;
00027 typedef std::vector<REAL> DataWgt;
00028 typedef const_shared_ptr<DataSet> pDataSet;
00029 typedef const_shared_ptr<DataWgt> pDataWgt;
00030
00032 DataSet* load_data (std::istream&, UINT, UINT, UINT);
00033 DataSet* load_data (std::istream&, UINT);
00034
00064 class LearnModel : public Object {
00065 protected:
00066 UINT _n_in;
00067 UINT _n_out;
00068 pDataSet ptd;
00069 pDataWgt ptw;
00070 UINT n_samples;
00071
00072 FILE* logf;
00073
00074 public:
00075 LearnModel (UINT n_in = 0, UINT n_out = 0);
00076
00078 virtual LearnModel* create () const = 0;
00079 virtual LearnModel* clone () const = 0;
00080
00081 UINT n_input () const { return _n_in; }
00082 UINT n_output () const { return _n_out; }
00083
00084 void set_log_file (FILE* f) { logf = f; }
00086
00088
00094 virtual bool support_weighted_data () const { return false; }
00095
00097 virtual REAL r_error (const Output& out, const Output& y) const;
00099 virtual REAL c_error (const Output& out, const Output& y) const;
00100
00102 REAL train_r_error () const;
00104 REAL train_c_error () const;
00106 REAL test_r_error (const pDataSet&) const;
00108 REAL test_c_error (const pDataSet&) const;
00109
00110 virtual void initialize () {
00111 std::cerr << "!!! initialize() is depreciated.\n"
00112 << "!!! See the documentation of LearnModel and reset().\n";
00113 }
00114
00116 virtual void set_train_data (const pDataSet&, const pDataWgt& = 0);
00118 const pDataSet& train_data () const { return ptd; }
00119
00120
00121
00122
00123
00126 virtual void train () = 0;
00127
00131 virtual void reset ();
00133
00134 virtual Output operator() (const Input&) const = 0;
00135
00139 virtual Output get_output (UINT idx) const {
00140 assert(ptw != 0);
00141 return operator()(ptd->x(idx)); }
00142
00144
00158 virtual REAL margin_norm () const { return 1; }
00160 virtual REAL margin_of (const Input& x, const Output& y) const;
00164 virtual REAL margin (UINT i) const {
00165 assert(ptw != 0);
00166 return margin_of(ptd->x(i), ptd->y(i)); }
00168 REAL min_margin () const;
00170
00171 bool valid_dimensions (UINT, UINT) const;
00172 inline bool valid_dimensions (const LearnModel& l) const {
00173 return valid_dimensions(l.n_input(), l.n_output()); }
00174
00175 inline bool exact_dimensions (UINT i, UINT o) const {
00176 return (i > 0 && o > 0 && valid_dimensions(i, o)); }
00177 inline bool exact_dimensions (const LearnModel& l) const {
00178 return exact_dimensions(l.n_input(), l.n_output()); }
00179 inline bool exact_dimensions (const DataSet& d) const {
00180 assert(d.size() > 0);
00181 return exact_dimensions(d.x(0).size(), d.y(0).size()); }
00182
00183 protected:
00184 void set_dimensions (UINT, UINT);
00185 inline void set_dimensions (const LearnModel& l) {
00186 set_dimensions(l.n_input(), l.n_output()); }
00187 inline void set_dimensions (const DataSet& d) {
00188 assert(exact_dimensions(d));
00189 set_dimensions(d.x(0).size(), d.y(0).size()); }
00190
00191 virtual bool serialize (std::ostream&, ver_list&) const;
00192 virtual bool unserialize (std::istream&, ver_list&, const id_t& = NIL_ID);
00193 };
00194
00195 typedef var_shared_ptr<LearnModel> pLearnModel;
00196 typedef const_shared_ptr<LearnModel> pcLearnModel;
00197
00198 }
00199
00200 #ifdef __LEARNMODEL_H__
00201 #warning "This header file may conflict with another `learnmodel.h' file."
00202 #endif
00203 #define __LEARNMODEL_H__
00204 #endif