learnmodel.h

Go to the documentation of this file.
00001 // -*- C++ -*-
00002 #ifndef __LEMGA_LEARNMODEL_H__
00003 #define __LEMGA_LEARNMODEL_H__
00004 
#include <assert.h>
#include <stdio.h>   // FILE, NULL
#include <iosfwd>    // std::istream, std::ostream
#include <vector>
#include "object.h"
#include "dataset.h"
#include "shared_ptr.h"
00018 
/// Verbosity setting consumed elsewhere in the library.
/// NOTE(review): no user of this macro is visible in this header; confirm
/// how the value 1 is interpreted before changing it.
#define VERBOSE_OUTPUT  1
00020 
00021 namespace lemga {
00022 
00023 typedef std::vector<REAL> Input;
00024 typedef std::vector<REAL> Output;
00025 
00026 typedef dataset<Input,Output> DataSet;
00027 typedef std::vector<REAL> DataWgt;
00028 typedef const_shared_ptr<DataSet> pDataSet;
00029 typedef const_shared_ptr<DataWgt> pDataWgt;
00030 
00032 DataSet* load_data (std::istream&, UINT, UINT, UINT);
00033 DataSet* load_data (std::istream&, UINT);
00034 
00064 class LearnModel : public Object {
00065 protected:
00066     UINT _n_in;       
00067     UINT _n_out;      
00068     pDataSet ptd;     
00069     pDataWgt ptw;     
00070     UINT n_samples;   
00071 
00072     FILE* logf;       
00073 
00074 public:
00075     LearnModel (UINT n_in = 0, UINT n_out = 0);
00076     LearnModel (const LearnModel&);
00077 
00079     virtual LearnModel* create () const = 0;
00080     virtual LearnModel* clone () const = 0;
00081 
00082     UINT n_input  () const { return _n_in;  }
00083     UINT n_output () const { return _n_out; }
00084 
00085     void set_log_file (FILE* f) { logf = f; }
00087 
00089 
00095     virtual bool support_weighted_data () const { return false; }
00096 
00098     virtual REAL r_error (const Output& out, const Output& y) const;
00100     virtual REAL c_error (const Output& out, const Output& y) const;
00101 
00103     REAL train_r_error () const;
00105     REAL train_c_error () const;
00107     REAL test_r_error (const pDataSet&) const;
00109     REAL test_c_error (const pDataSet&) const;
00110 
00116     virtual void initialize () {}
00117 
00119     virtual void set_train_data (const pDataSet&, const pDataWgt& = 0);
00121     const pDataSet& train_data () const { return ptd; }
00122     const pDataWgt& data_weight () const { return ptw; }
00123 
00128     virtual REAL train () = 0;
00130 
00131     virtual Output operator() (const Input&) const = 0;
00132 
00136     virtual Output get_output (UINT idx) const {
00137         assert(ptw != NULL); // no data sampling
00138         return operator()(ptd->x(idx)); }
00139 
00141 
00155     virtual REAL margin_norm () const { return 1; }
00157     virtual REAL margin_of (const Input& x, const Output& y) const;
00161     virtual REAL margin (UINT i) const {
00162         assert(ptw != NULL); // no data sampling
00163         return margin_of(ptd->x(i), ptd->y(i)); }
00165     REAL min_margin () const;
00167 
00168 protected:
00169     virtual bool serialize (std::ostream&, ver_list&) const;
00170     virtual bool unserialize (std::istream&, ver_list&,
00171                               const id_t& = empty_id);
00172 };
00173 
00174 } // namespace lemga
00175 
// Warn when some other header has already defined the generic
// __LEARNMODEL_H__ guard.  `#warning` is a GCC/Clang extension (only
// standardized in C++23) that MSVC has historically rejected as an invalid
// directive -- turning this warning into a hard error -- so route the same
// text through `#pragma message` there.
#ifdef  __LEARNMODEL_H__
#  ifdef _MSC_VER
#    pragma message("This header file may conflict with another `learnmodel.h' file.")
#  else
#    warning "This header file may conflict with another `learnmodel.h' file."
#  endif
#endif
#define __LEARNMODEL_H__
00180 #endif

Generated on Mon Jan 9 23:43:24 2006 for LEMGA by doxygen 1.4.6