00001
00002 #ifndef __LEMGA_PERCEPTRON_H__
00003 #define __LEMGA_PERCEPTRON_H__
00004
00011 #include <vector>
00012 #include "learnmodel.h"
00013 #include "svm.h"
00014
00015 namespace lemga {
00016
00035 class Perceptron : public LearnModel {
00036 public:
00037 typedef std::vector<REAL> WEIGHT;
00038 enum TRAIN_METHOD {
00039
00040 PERCEPTRON,
00041 ADALINE,
00042 POCKET,
00043 POCKET_RATCHET,
00044 AVE_PERCEPTRON,
00045 ROMMA,
00046 ROMMA_AGG,
00047 SGD_HINGE,
00048 SGD_MLSE,
00049
00050 RCD,
00051 RCD_BIAS,
00052 RCD_GRAD,
00053
00054 AVE_PERCEPTRON_RAND,
00055 ROMMA_RAND,
00056 ROMMA_AGG_RAND,
00057 COORDINATE_DESCENT,
00058 FIXED_RCD,
00059 FIXED_RCD_CONJ,
00060 FIXED_RCD_BIAS,
00061 FIXED_RCD_CONJ_BIAS,
00062 RCD_CONJ,
00063 RCD_CONJ_BIAS,
00064 RCD_GRAD_BATCH,
00065 RCD_GRAD_RAND,
00066 RCD_GRAD_BATCH_RAND,
00067 RCD_MIXED,
00068 RCD_GRAD_MIXED,
00069 RCD_GRAD_MIXED_INITRAND,
00070 RCD_GRAD_MIXED_BATCH,
00071 RCD_GRAD_MIXED_BATCH_INITRAND,
00072
00073 RAND_COOR_DESCENT = RCD,
00074 RAND_COOR_DESCENT_BIAS = RCD_BIAS,
00075 RAND_CONJ_DESCENT = RCD_CONJ,
00076 RAND_CONJ_DESCENT_BIAS = RCD_CONJ_BIAS,
00077 GRADIENT_COOR_DESCENT_ONLINE = RCD_GRAD
00078 };
00079
00080 protected:
00081 WEIGHT wgt;
00082
00083 bool resample;
00084 TRAIN_METHOD train_method;
00085 REAL learn_rate, min_cst;
00086 UINT max_run;
00087 bool with_fld;
00088 bool fixed_bias;
00089
00090 public:
00091 explicit Perceptron (UINT n_in = 0);
00092 Perceptron (const SVM&);
00093 explicit Perceptron (std::istream& is) { is >> *this; }
00094
00095 virtual const id_t& id () const;
00096 virtual Perceptron* create () const { return new Perceptron(); }
00097 virtual Perceptron* clone () const { return new Perceptron(*this); }
00098
00099 WEIGHT weight () const { return wgt; }
00100 void set_weight (const WEIGHT&);
00101
00102 void start_with_fld (bool b = true) { with_fld = b; }
00103 void set_fixed_bias (bool b = false) { fixed_bias = b; }
00104 void use_resample (bool s = true) { resample = true; }
00105 void set_train_method (TRAIN_METHOD m) { train_method = m; }
00111 void set_parameter (REAL lr, REAL mincst, UINT maxrun) {
00112 learn_rate = lr; min_cst = mincst; max_run = maxrun; }
00113
00114 virtual bool support_weighted_data () const { return true; }
00115 virtual void initialize ();
00116 WEIGHT fld () const;
00117 virtual void train ();
00118 virtual Output operator() (const Input&) const;
00119
00120 virtual REAL margin_norm () const { return w_norm(); }
00121 virtual REAL margin_of (const Input&, const Output&) const;
00122 REAL w_norm () const;
00123
00124 protected:
00125 virtual bool serialize (std::ostream&, ver_list&) const;
00126 virtual bool unserialize (std::istream&, ver_list&, const id_t& = NIL_ID);
00127 virtual void log_error (UINT, REAL = -1) const;
00128 };
00129
00130 }
00131
00132 #ifdef __PERCEPTRON_H__
00133 #warning "This header file may conflict with another `perceptron.h' file."
00134 #endif
00135 #define __PERCEPTRON_H__
00136 #endif