00001
00002 #ifndef __LEMGA_LEARNMODEL_H__
00003 #define __LEMGA_LEARNMODEL_H__
00004
00013 #include <assert.h>
00014 #include <vector>
00015 #include "object.h"
00016 #include "dataset.h"
00017 #include "shared_ptr.h"
00018
// Verbosity switch, defined to 1. Nothing in this header reads it, so its
// consumers live elsewhere in the project.
00019 #define VERBOSE_OUTPUT 1
00020
00021 namespace lemga {
00022
00023 typedef std::vector<REAL> Input;        // A model input: a vector of REAL components.
00024 typedef std::vector<REAL> Output;       // A model output: a vector of REAL components.
00025
00026 typedef dataset<Input,Output> DataSet;  // Collection of (Input, Output) samples (dataset.h); LearnModel reads sample i via x(i) / y(i).
00027 typedef std::vector<REAL> DataWgt;      // Sample weights; NOTE(review): presumably one weight per DataSet sample -- confirm.
00028 typedef const_shared_ptr<DataSet> pDataSet;  // const_shared_ptr (shared_ptr.h) to a DataSet; presumably shared, read-only access.
00029 typedef const_shared_ptr<DataWgt> pDataWgt;  // const_shared_ptr to a DataWgt; may be null (LearnModel::set_train_data defaults it to 0).
00030
// Read a DataSet from a stream and return it through a raw pointer.
// NOTE(review): the meaning of the UINT arguments (presumably the sample
// count and the input/output dimensions) and who deletes the returned
// object are not shown in this header -- confirm against the definition.
00032 DataSet* load_data (std::istream&, UINT, UINT, UINT);
// Overload taking a single UINT; presumably the sample count, with the
// dimensions determined some other way (e.g. from the stream) -- confirm.
00033 DataSet* load_data (std::istream&, UINT);
00034
00064 class LearnModel : public Object {
00065 protected:
00066 UINT _n_in;
00067 UINT _n_out;
00068 pDataSet ptd;
00069 pDataWgt ptw;
00070 UINT n_samples;
00071
00072 FILE* logf;
00073
00074 public:
00075 LearnModel (UINT n_in = 0, UINT n_out = 0);
00076 LearnModel (const LearnModel&);
00077
00079 virtual LearnModel* create () const = 0;
00080 virtual LearnModel* clone () const = 0;
00081
00082 UINT n_input () const { return _n_in; }
00083 UINT n_output () const { return _n_out; }
00084
00085 void set_log_file (FILE* f) { logf = f; }
00087
00089
00095 virtual bool support_weighted_data () const { return false; }
00096
00098 virtual REAL r_error (const Output& out, const Output& y) const;
00100 virtual REAL c_error (const Output& out, const Output& y) const;
00101
00103 REAL train_r_error () const;
00105 REAL train_c_error () const;
00107 REAL test_r_error (const pDataSet&) const;
00109 REAL test_c_error (const pDataSet&) const;
00110
00116 virtual void initialize () {}
00117
00119 virtual void set_train_data (const pDataSet&, const pDataWgt& = 0);
00121 const pDataSet& train_data () const { return ptd; }
00122 const pDataWgt& data_weight () const { return ptw; }
00123
00128 virtual REAL train () = 0;
00130
00131 virtual Output operator() (const Input&) const = 0;
00132
00136 virtual Output get_output (UINT idx) const {
00137 assert(ptw != NULL);
00138 return operator()(ptd->x(idx)); }
00139
00141
00155 virtual REAL margin_norm () const { return 1; }
00157 virtual REAL margin_of (const Input& x, const Output& y) const;
00161 virtual REAL margin (UINT i) const {
00162 assert(ptw != NULL);
00163 return margin_of(ptd->x(i), ptd->y(i)); }
00165 REAL min_margin () const;
00167
00168 protected:
00169 virtual bool serialize (std::ostream&, ver_list&) const;
00170 virtual bool unserialize (std::istream&, ver_list&,
00171 const id_t& = empty_id);
00172 };
00173
00174 }
00175
00176 #ifdef __LEARNMODEL_H__
00177 #warning "This header file may conflict with another `learnmodel.h' file."
00178 #endif
00179 #define __LEARNMODEL_H__
00180 #endif