00001
00002 #ifndef __LEMGA_MULTICLASS_ECOC_H__
00003 #define __LEMGA_MULTICLASS_ECOC_H__
00004
00012 #include <vector>
00013 #include "aggregating.h"
00014
00015
// Compile-time selection of the per-sample cache kept by MultiClass_ECOC:
//   0 - no cache members are declared,
//   1 - cache of Output objects (cache_o),
//   2 - "distance cache" (cache_n / cache_d).
// Any other value stops compilation with an #error inside the class body.
00016 #define MULTICLASS_ECOC_OUTPUT_CACHE 2
00017
00018 namespace lemga {
00019
/// A sequence of integer code entries (used below both for partitions
/// handed to the training hooks and for rows of an ECOC_TABLE).
typedef std::vector<int> ECOC_VECTOR;

/// A two-level table of code entries: a vector of ECOC_VECTORs.
typedef std::vector<ECOC_VECTOR> ECOC_TABLE;
/// Tag naming a predefined ECOC table layout; NO_TYPE is the value both
/// MultiClass_ECOC constructors start from.
enum ECOC_TYPE {
    NO_TYPE    = 0,   ///< no predefined layout selected
    ONE_VS_ONE = 1,   ///< presumably pairwise class-vs-class coding -- confirm in the .cpp
    ONE_VS_ALL = 2    ///< presumably class-vs-rest coding -- confirm in the .cpp
};
00027
00030 class MultiClass_ECOC : public Aggregating {
00031 protected:
00032 std::vector<REAL> lm_wgt;
00033 ECOC_TABLE ecoc;
00034
00035
00036 ECOC_TYPE ecoc_type;
00037
00038
00039 UINT nclass;
00040 std::vector<REAL> labels;
00041 std::vector<UINT> ex_class;
00042
00043 public:
00044 MultiClass_ECOC () : Aggregating(), ecoc_type(NO_TYPE), nclass(0)
00045 { set_dimensions(0, 1); }
00046 explicit MultiClass_ECOC (std::istream& is) : ecoc_type(NO_TYPE)
00047 { is >> *this; }
00048
00049 virtual const id_t& id () const;
00050 virtual MultiClass_ECOC* create () const { return new MultiClass_ECOC(); }
00051 virtual MultiClass_ECOC* clone () const {
00052 return new MultiClass_ECOC(*this); }
00053
00054 REAL model_weight (UINT n) const { return lm_wgt[n]; }
00055 const ECOC_TABLE& ECOC_table () const { return ecoc; }
00056 void set_ECOC_table (const ECOC_TABLE&);
00057 void set_ECOC_table (ECOC_TYPE);
00058 void set_ECOC_table (UINT, const ECOC_VECTOR&);
00059 UINT n_class () const { return nclass; }
00060
00061 virtual bool support_weighted_data () const { return true; }
00062 virtual REAL c_error (const Output& out, const Output& y) const;
00063 virtual void set_train_data (const pDataSet&, const pDataWgt& = 0);
00064 virtual void train ();
00065 virtual void reset ();
00066 virtual Output operator() (const Input&) const;
00067 virtual Output get_output (UINT idx) const;
00068
00069 virtual REAL margin (UINT) const;
00070 virtual REAL margin_of (const Input&, const Output&) const;
00072 REAL cost () const;
00073
00074 #if MULTICLASS_ECOC_OUTPUT_CACHE
00075 private:
00076 #if MULTICLASS_ECOC_OUTPUT_CACHE == 2 // distance cache
00077 mutable std::vector<UINT> cache_n;
00078 mutable std::vector<std::vector<REAL> > cache_d;
00079 inline void clear_cache (UINT i) const {
00080 assert(i < n_samples && cache_n.size() == n_samples && nclass > 0);
00081 cache_n[i] = 0;
00082 std::vector<REAL> cdi(nclass, 0);
00083 cache_d[i].swap(cdi);
00084 }
00085 inline void clear_cache () const {
00086 cache_d.resize(n_samples);
00087 cache_n.resize(n_samples);
00088 for (UINT i = 0; i < n_samples; ++i)
00089 clear_cache(i);
00090 }
00091 #elif MULTICLASS_ECOC_OUTPUT_CACHE == 1
00092 mutable std::vector<Output> cache_o;
00093 inline void clear_cache () const {
00094 std::vector<Output> co(n_samples);
00095 cache_o.swap(co);
00096 }
00097 #else
00098 #error "Wrong value of MULTICLASS_ECOC_OUTPUT_CACHE in `multiclass_ecoc.h'."
00099 #endif
00100 #endif // MULTICLASS_ECOC_OUTPUT_CACHE
00101
00102 protected:
00103 virtual REAL ECOC_distance (const Output&, const ECOC_VECTOR&) const;
00104 #if MULTICLASS_ECOC_OUTPUT_CACHE == 2 // distance cache
00105 virtual REAL ECOC_distance (REAL, int, REAL, REAL = 0) const;
00106 #endif
00107
00108 mutable std::vector<REAL> local_d;
00111 const std::vector<REAL>& distances (const Input&) const;
00112 const std::vector<REAL>& distances (UINT) const;
00113
00115 bool is_full_partition (const ECOC_VECTOR&) const;
00116
00118 virtual void setup_aux () {}
00119 virtual bool ECOC_partition (UINT, ECOC_VECTOR&) const;
00120 virtual pLearnModel train_with_partition (ECOC_VECTOR&) const;
00121 virtual REAL assign_weight (const ECOC_VECTOR&, const LearnModel&) const
00122 { return 1; }
00124 virtual void update_aux (const ECOC_VECTOR&) {}
00125
00126 virtual bool serialize (std::ostream&, ver_list&) const;
00127 virtual bool unserialize (std::istream&, ver_list&, const id_t& = NIL_ID);
00128 };
00129
00130 }
00131
// Conflict detection: if the unqualified marker __MULTICLASS_ECOC_H__ is
// already defined when this header is processed, emit a warning; then
// define the marker here.
00132 #ifdef __MULTICLASS_ECOC_H__
00133 #warning "This header file may conflict with another `multiclass_ecoc.h' file."
00134 #endif
00135 #define __MULTICLASS_ECOC_H__
00136 #endif