multiclass_ecoc.h

Go to the documentation of this file.
00001 // -*- C++ -*-
00002 #ifndef __LEMGA_MULTICLASS_ECOC_H__
00003 #define __LEMGA_MULTICLASS_ECOC_H__
00004 
00012 #include <vector>
00013 #include "aggregating.h"
00014 
00015 // Compile-time cache mode: 0 - no cache; 1 - cache output; 2 - cache distance. Any other non-zero value triggers the #error in the class body below.
00016 #define MULTICLASS_ECOC_OUTPUT_CACHE    2
00017 
00018 namespace lemga {
00019 
00020 typedef std::vector<int> ECOC_VECTOR;          // one vector of int code entries
00021 typedef std::vector<ECOC_VECTOR> ECOC_TABLE;   // a table made of ECOC_VECTORs; NOTE(review): row/column meaning is not visible in this header -- confirm in the .cpp
00022 enum ECOC_TYPE {                               // selector accepted by MultiClass_ECOC::set_ECOC_table(ECOC_TYPE)
00023     ONE_VS_ONE,                                // presumably the pairwise (one class against one class) coding -- TODO confirm
00024     ONE_VS_ALL                                 // presumably the one-class-against-the-rest coding -- TODO confirm
00025 };
00026 
00029 class MultiClass_ECOC : public Aggregating {  // Aggregating model that carries an ECOC table, per-model weights and class bookkeeping
00030 protected:
00031     std::vector<REAL> lm_wgt;   ///< per-model weights; model_weight(n) returns lm_wgt[n]
00032     ECOC_TABLE ecoc;            ///< the ECOC table exposed read-only by ECOC_table()
00033 
00034     // variables extracted from the training data
00035     UINT nclass;                ///< number of classes (returned by n_class()); 0 after default construction
00036     std::vector<REAL> labels;   ///< presumably the distinct class labels -- TODO confirm in the .cpp
00037     std::vector<UINT> ex_class; ///< presumably the class index of each training example -- TODO confirm
00038 
00039 public:
00040     MultiClass_ECOC () : Aggregating(), nclass(0) {}  // empty model with nclass set to 0
00041     explicit MultiClass_ECOC (std::istream& is) { is >> *this; }  // builds the object by reading it from `is` with operator>>; NOTE(review): nclass is not initialized here before the read -- presumably the read sets it, confirm
00042 
00043     virtual const id_t& id () const;
00044     virtual MultiClass_ECOC* create () const { return new MultiClass_ECOC(); }  // returns a heap-allocated, default-constructed object
00045     virtual MultiClass_ECOC* clone () const {  // returns a heap-allocated copy of *this made with the copy constructor
00046         return new MultiClass_ECOC(*this); }
00047 
00048     REAL model_weight (UINT n) const { return lm_wgt[n]; }  // weight stored at index n; no bounds check is done
00049     const ECOC_TABLE& ECOC_table () const { return ecoc; }  // read-only access to the stored table
00050     void set_ECOC_table (const ECOC_TABLE&);  // defined elsewhere; presumably replaces the whole table -- TODO confirm
00051     void set_ECOC_table (ECOC_TYPE);  // defined elsewhere; presumably builds the table for a predefined coding -- TODO confirm
00052     void set_ECOC_table (UINT, const ECOC_VECTOR&);  // defined elsewhere; presumably sets the table entry at the given index -- TODO confirm
00053     UINT n_class () const { return nclass; }  // current value of nclass
00054 
00055     virtual bool support_weighted_data () const { return true; }  // always reports true
00056     virtual REAL c_error (const Output& out, const Output& y) const;
00057     virtual void initialize ();
00058     virtual void set_train_data (const pDataSet&, const pDataWgt& = 0);  // the sample-weight argument is optional (defaults to 0)
00059     virtual REAL train ();
00060     virtual Output operator() (const Input&) const;
00061     virtual Output get_output (UINT idx) const;
00062 
00063     virtual REAL margin (UINT) const;
00064     virtual REAL margin_of (const Input&, const Output&) const;
00066     REAL cost () const;
00067 
00068 #if MULTICLASS_ECOC_OUTPUT_CACHE  // cache members exist only when the cache mode is non-zero
00069 private:
00070 #if MULTICLASS_ECOC_OUTPUT_CACHE == 2  // distance cache
00071     mutable std::vector<UINT> cache_n;  // one counter per sample, reset to 0 by clear_cache(i); mutable so const members can update it
00072     mutable std::vector<std::vector<REAL> > cache_d;  // one vector per sample, reset to nclass zeros by clear_cache(i)
00073     inline void clear_cache (UINT i) const {  // resets the cache entry of sample i
00074         assert(i < n_samples && cache_n.size() == n_samples && nclass > 0);  // NOTE(review): <cassert> is not included by this header directly (presumably via aggregating.h) and cache_d.size() is not checked; n_samples is not declared in this class -- presumably inherited, confirm
00075         cache_n[i] = 0;
00076         std::vector<REAL> cdi(nclass, 0);  // nclass zeros
00077         cache_d[i].swap(cdi);  // swap the zeros in; the previous contents are destroyed with cdi
00078     }
00079     inline void clear_cache () const {  // sizes both caches to n_samples and resets every entry
00080         cache_d.resize(n_samples);
00081         cache_n.resize(n_samples);
00082         for (UINT i = 0; i < n_samples; ++i)
00083             clear_cache(i);
00084     }
00085 #elif MULTICLASS_ECOC_OUTPUT_CACHE == 1  // output cache
00086     mutable std::vector<Output> cache_o;  // one Output per sample
00087     inline void clear_cache () const {  // replaces the cache with n_samples default-constructed Outputs
00088         std::vector<Output> co(n_samples);
00089         cache_o.swap(co);
00090     }
00091 #else  // any other non-zero mode is rejected at compile time
00092 #error "Wrong value of MULTICLASS_ECOC_OUTPUT_CACHE in `multiclass_ecoc.h'."
00093 #endif
00094 #endif // MULTICLASS_ECOC_OUTPUT_CACHE
00095 
00096 protected:
00097     virtual REAL ECOC_distance (const Output&, const ECOC_VECTOR&) const;
00098 #if MULTICLASS_ECOC_OUTPUT_CACHE == 2  // distance cache
00099     virtual REAL ECOC_distance (REAL, int, REAL, REAL = 0) const;  // extra overload declared only in distance-cache mode; the last argument defaults to 0
00100 #endif
00101 
00102     mutable std::vector<REAL> local_d;  // mutable so const members can write it; presumably scratch storage behind distances() -- TODO confirm
00105     const std::vector<REAL>& distances (const Input&) const;
00106     const std::vector<REAL>& distances (UINT) const;
00107 
00109     virtual void reset_training () {}  // does nothing by default; virtual so subclasses can override
00110     virtual bool ECOC_partition (UINT, ECOC_VECTOR&);
00111     virtual pLearnModel train_with_partition (ECOC_VECTOR&);
00112     virtual REAL assign_weight (const ECOC_VECTOR&, const LearnModel&);
00114     virtual void update_training (const ECOC_VECTOR&) {}  // does nothing by default; virtual so subclasses can override
00115 
00116     virtual bool serialize (std::ostream&, ver_list&) const;
00117     virtual bool unserialize (std::istream&, ver_list&,
00118                               const id_t& = empty_id);  // the id argument is optional (defaults to empty_id)
00119 };
00120 
00121 } // namespace lemga
00122 
00123 #ifdef  __MULTICLASS_ECOC_H__  // the un-prefixed guard name is already defined, i.e. some other header claimed it first
00124 #warning "This header file may conflict with another `multiclass_ecoc.h' file."
00125 #endif
00126 #define __MULTICLASS_ECOC_H__  // claim the un-prefixed name as well, so the check above can fire in either inclusion order
00127 #endif

Generated on Mon Jan 9 23:43:24 2006 for LEMGA by doxygen 1.4.6