learnmodel.h

Go to the documentation of this file.
00001 // -*- C++ -*-
00002 #ifndef __LEMGA_LEARNMODEL_H__
00003 #define __LEMGA_LEARNMODEL_H__
00004 
00013 #include <assert.h>
00014 #include <vector>
00015 #include "object.h"
00016 #include "dataset.h"
00017 #include "shared_ptr.h"
00018 
00019 #define VERBOSE_OUTPUT  1  // compile-time flag set to 1; its users are outside this header (presumably toggles extra diagnostics -- confirm)
00020 
00021 namespace lemga {
00022 
00023 typedef std::vector<REAL> Input;   ///< model input: a vector of REAL values
00024 typedef std::vector<REAL> Output;  ///< model output / target: a vector of REAL values
00025 
00026 typedef dataset<Input,Output> DataSet;        ///< the dataset template instantiated on Input and Output
00027 typedef std::vector<REAL> DataWgt;            ///< weights for a DataSet (presumably one REAL per sample -- confirm)
00028 typedef const_shared_ptr<DataSet> pDataSet;   ///< const_shared_ptr handle to a DataSet
00029 typedef const_shared_ptr<DataWgt> pDataWgt;   ///< const_shared_ptr handle to a DataWgt
00030 
00032 DataSet* load_data (std::istream&, UINT, UINT, UINT);  ///< read a DataSet from a stream; the three UINTs are unnamed here (presumably sample count and input/output sizes -- confirm). Returns a raw pointer; ownership is presumably the caller's -- confirm
00033 DataSet* load_data (std::istream&, UINT);              ///< overload taking a single UINT whose meaning is not visible here -- confirm against the definition
00034 
00064 class LearnModel : public Object {
00065 protected:
00066     UINT _n_in;       
00067     UINT _n_out;      
00068     pDataSet ptd;     
00069     pDataWgt ptw;     
00070     UINT n_samples;   
00071 
00072     FILE* logf;       
00073 
00074 public:
00075     LearnModel (UINT n_in = 0, UINT n_out = 0);
00076 
00078     virtual LearnModel* create () const = 0;
00079     virtual LearnModel* clone () const = 0;
00080 
00081     UINT n_input  () const { return _n_in;  }
00082     UINT n_output () const { return _n_out; }
00083 
00084     void set_log_file (FILE* f) { logf = f; }
00086 
00088 
00094     virtual bool support_weighted_data () const { return false; }
00095 
00097     virtual REAL r_error (const Output& out, const Output& y) const;
00099     virtual REAL c_error (const Output& out, const Output& y) const;
00100 
00102     REAL train_r_error () const;
00104     REAL train_c_error () const;
00106     REAL test_r_error (const pDataSet&) const;
00108     REAL test_c_error (const pDataSet&) const;
00109 
00110     virtual void initialize () {
00111         std::cerr << "!!! initialize() is depreciated.\n"
00112                   << "!!! See the documentation of LearnModel and reset().\n";
00113     }
00114 
00116     virtual void set_train_data (const pDataSet&, const pDataWgt& = 0);
00118     const pDataSet& train_data () const { return ptd; }
00119     /* temporarily disabled; ptw in boosting base learners is reset
00120      * after the training; disabled to be sure no one actually uses it
00121     const pDataWgt& data_weight () const { return ptw; }
00122      */
00123 
00126     virtual void train () = 0;
00127 
00131     virtual void reset ();
00133 
00134     virtual Output operator() (const Input&) const = 0;
00135 
00139     virtual Output get_output (UINT idx) const {
00140         assert(ptw != 0); // no data sampling
00141         return operator()(ptd->x(idx)); }
00142 
00144 
00158     virtual REAL margin_norm () const { return 1; }
00160     virtual REAL margin_of (const Input& x, const Output& y) const;
00164     virtual REAL margin (UINT i) const {
00165         assert(ptw != 0); // no data sampling
00166         return margin_of(ptd->x(i), ptd->y(i)); }
00168     REAL min_margin () const;
00170 
00171     bool valid_dimensions (UINT, UINT) const;
00172     inline bool valid_dimensions (const LearnModel& l) const {
00173         return valid_dimensions(l.n_input(), l.n_output()); }
00174 
00175     inline bool exact_dimensions (UINT i, UINT o) const {
00176         return (i > 0 && o > 0 && valid_dimensions(i, o)); }
00177     inline bool exact_dimensions (const LearnModel& l) const {
00178         return exact_dimensions(l.n_input(), l.n_output()); }
00179     inline bool exact_dimensions (const DataSet& d) const {
00180         assert(d.size() > 0);
00181         return exact_dimensions(d.x(0).size(), d.y(0).size()); }
00182 
00183 protected:
00184     void set_dimensions (UINT, UINT);
00185     inline void set_dimensions (const LearnModel& l) {
00186         set_dimensions(l.n_input(), l.n_output()); }
00187     inline void set_dimensions (const DataSet& d) {
00188         assert(exact_dimensions(d));
00189         set_dimensions(d.x(0).size(), d.y(0).size()); }
00190 
00191     virtual bool serialize (std::ostream&, ver_list&) const;
00192     virtual bool unserialize (std::istream&, ver_list&, const id_t& = NIL_ID);
00193 };
00194 
00195 typedef var_shared_ptr<LearnModel> pLearnModel;     ///< var_shared_ptr handle to a LearnModel (presumably allows modification -- confirm in shared_ptr.h)
00196 typedef const_shared_ptr<LearnModel> pcLearnModel;  ///< const_shared_ptr handle to a LearnModel (presumably read-only access -- confirm in shared_ptr.h)
00197 
00198 } // namespace lemga
00199 
00200 #ifdef  __LEARNMODEL_H__  // the un-prefixed guard name is already defined, i.e. some other header claimed it
00201 #warning "This header file may conflict with another `learnmodel.h' file."
00202 #endif
00203 #define __LEARNMODEL_H__  // define the un-prefixed name as well (presumably so a same-named header can detect this one -- confirm)
00204 #endif

Generated on Wed Nov 8 08:15:21 2006 for LEMGA by doxygen 1.4.6