00001
00002 #ifndef __LEMGA_PERCEPTRON_H__
00003 #define __LEMGA_PERCEPTRON_H__
00004
00011 #include <vector>
00012 #include "learnmodel.h"
00013 #include "svm.h"
00014
00015 namespace lemga {
00016
00035 class Perceptron : public LearnModel {
00036 public:
00037 typedef std::vector<REAL> WEIGHT;
00038 enum TRAIN_METHOD {
00039
00040 PERCEPTRON,
00041 ADALINE,
00042 POCKET,
00043 POCKET_RATCHET,
00044 AVE_PERCEPTRON,
00045 ROMMA,
00046 ROMMA_AGG,
00047 SGD_HINGE,
00048 SGD_MLSE,
00049
00050 RCD,
00051 RCD_BIAS,
00052 RCD_GRAD,
00053
00054 AVE_PERCEPTRON_RAND,
00055 ROMMA_RAND,
00056 ROMMA_AGG_RAND,
00057 COORDINATE_DESCENT,
00058 FIXED_RCD,
00059 FIXED_RCD_CONJ,
00060 FIXED_RCD_BIAS,
00061 FIXED_RCD_CONJ_BIAS,
00062 RCD_CONJ,
00063 RCD_CONJ_BIAS,
00064 RCD_GRAD_BATCH,
00065 RCD_GRAD_RAND,
00066 RCD_GRAD_BATCH_RAND,
00067 RCD_MIXED,
00068 RCD_GRAD_MIXED,
00069 RCD_GRAD_MIXED_INITRAND,
00070 RCD_GRAD_MIXED_BATCH,
00071 RCD_GRAD_MIXED_BATCH_INITRAND,
00072
00073 RAND_COOR_DESCENT = RCD,
00074 RAND_COOR_DESCENT_BIAS = RCD_BIAS,
00075 RAND_CONJ_DESCENT = RCD_CONJ,
00076 RAND_CONJ_DESCENT_BIAS = RCD_CONJ_BIAS,
00077 GRADIENT_COOR_DESCENT_ONLINE = RCD_GRAD
00078 };
00079
00080 protected:
00081 WEIGHT wgt;
00082
00083 bool resample;
00084 TRAIN_METHOD train_method;
00085 REAL learn_rate, min_cst;
00086 UINT max_run;
00087 bool with_fld;
00088
00089 public:
00090 explicit Perceptron (UINT n_in = 0);
00091 Perceptron (const SVM&);
00092 explicit Perceptron (std::istream& is) { is >> *this; }
00093
00094 virtual const id_t& id () const;
00095 virtual Perceptron* create () const { return new Perceptron(); }
00096 virtual Perceptron* clone () const { return new Perceptron(*this); }
00097
00098 WEIGHT weight () const { return wgt; }
00099 void set_weight (const WEIGHT&);
00100
00101 void start_with_fld (bool b = true) { with_fld = b; }
00102 void use_resample (bool s = true) { resample = true; }
00103 void set_train_method (TRAIN_METHOD m) { train_method = m; }
00109 void set_parameter (REAL lr, REAL mincst, UINT maxrun) {
00110 learn_rate = lr; min_cst = mincst; max_run = maxrun; }
00111
00112 virtual bool support_weighted_data () const { return true; }
00113 virtual void initialize ();
00114 WEIGHT fld () const;
00115 virtual REAL train ();
00116 virtual Output operator() (const Input&) const;
00117
00118 virtual REAL margin_norm () const { return w_norm(); }
00119 virtual REAL margin_of (const Input&, const Output&) const;
00120 REAL w_norm () const;
00121
00122 protected:
00123 virtual bool serialize (std::ostream&, ver_list&) const;
00124 virtual bool unserialize (std::istream&, ver_list&,
00125 const id_t& = empty_id);
00126 virtual void log_error (UINT, REAL = -1) const;
00127 };
00128
00132 bool ldivide (std::vector<std::vector<REAL> >& A,
00133 const std::vector<REAL>& b, std::vector<REAL>& x);
00134
00135 }
00136
00137 #ifdef __PERCEPTRON_H__
00138 #warning "This header file may conflict with another `perceptron.h' file."
00139 #endif
00140 #define __PERCEPTRON_H__
00141 #endif