00001
00002 #ifndef __LEMGA_KERNEL_H__
00003 #define __LEMGA_KERNEL_H__
00004
00011 #include <cmath>
00012 #include <numeric>
00013 #include "learnmodel.h"
00014
/// Dot product of two containers, accumulated in double precision
/// (the initial value is the double literal .0).
/// Both arguments are parenthesized so that expression arguments such as
/// `*ptr` expand correctly. Note that @a x is evaluated twice, so it must
/// be free of side effects and must name the same container each time.
/// @a y is read in lock step with @a x and is expected to hold at least
/// as many elements.
#define DOTPROD(x,y) std::inner_product((x).begin(), (x).end(), (y).begin(), .0)
00016
00017 namespace lemga {
00018
00019
00020 struct SVM_detail;
00021
00022 namespace kernel {
00023
00024 inline REAL norm_1 (const Input& u, const Input& v) {
00025 REAL sum(0);
00026 Input::const_iterator x = u.begin(), y = v.begin();
00027 for (; x != u.end(); ++x, ++y)
00028 sum += std::fabs(*x - *y);
00029 return sum;
00030 }
00031
00032 inline REAL norm_2 (const Input& u, const Input& v) {
00033 REAL sum(0);
00034 Input::const_iterator x = u.begin(), y = v.begin();
00035 for (; x != u.end(); ++x, ++y) {
00036 REAL d = *x - *y;
00037 sum += d * d;
00038 }
00039 return sum;
00040 }
00041
00043 class Kernel : public Object {
00044 public:
00045 virtual Kernel* create () const = 0;
00046 virtual Kernel* clone () const = 0;
00047
00049 virtual REAL operator() (const Input&, const Input&) const = 0;
00051 virtual void set_data (const pDataSet& pd) { ptd = pd; }
00053 virtual REAL matrix (UINT i, UINT j) const
00054 { return operator()(ptd->x(i), ptd->x(j)); }
00055
00059 virtual void set_params (SVM_detail*) const = 0;
00060
00061 protected:
00062 pDataSet ptd;
00063 virtual bool serialize (std::ostream&, ver_list&) const;
00064 virtual bool unserialize (std::istream&, ver_list&, const id_t& = NIL_ID);
00065 };
00066
00068 struct Linear : public Kernel {
00069 Linear () {}
00070 explicit Linear (std::istream& is) { is >> *this; }
00071
00072 virtual const id_t& id () const;
00073 virtual Linear* create () const { return new Linear(); }
00074 virtual Linear* clone () const { return new Linear(*this); }
00075
00076 virtual REAL operator() (const Input& a, const Input& b) const {
00077 return DOTPROD(a, b);
00078 }
00079 virtual void set_params (SVM_detail*) const;
00080 };
00081
00083 struct Polynomial : public Kernel {
00084 UINT degree;
00085 REAL gamma, coef0;
00086
00087 Polynomial (UINT d = 3, REAL g = 0.5, REAL c0 = 0)
00088 : degree(d), gamma(g), coef0(c0) {};
00089 explicit Polynomial (std::istream& is) { is >> *this; }
00090
00091 virtual const id_t& id () const;
00092 virtual Polynomial* create () const { return new Polynomial(); }
00093 virtual Polynomial* clone () const { return new Polynomial(*this); }
00094
00095 virtual REAL operator() (const Input& a, const Input& b) const {
00096 return std::pow(gamma * DOTPROD(a, b) + coef0, (double) degree);
00097 }
00098 virtual void set_params (SVM_detail*) const;
00099
00100 protected:
00101 virtual bool serialize (std::ostream&, ver_list&) const;
00102 virtual bool unserialize (std::istream&, ver_list&, const id_t& = NIL_ID);
00103 };
00104
00106 struct Stump : public Kernel {
00107 Stump () {}
00108 explicit Stump (std::istream& is) { is >> *this; }
00109
00110 virtual const id_t& id () const;
00111 virtual Stump* create () const { return new Stump(); }
00112 virtual Stump* clone () const { return new Stump(*this); }
00113
00114 virtual REAL operator() (const Input& a, const Input& b) const {
00115 return -norm_1(a, b);
00116 }
00117 virtual void set_params (SVM_detail*) const;
00118 };
00119
00121 struct Perceptron : public Kernel {
00122 protected:
00123 std::vector<REAL> x_norm2;
00124
00125 public:
00126 Perceptron () {}
00127 explicit Perceptron (std::istream& is) { is >> *this; }
00128
00129 virtual const id_t& id () const;
00130 virtual Perceptron* create () const { return new Perceptron(); }
00131 virtual Perceptron* clone () const { return new Perceptron(*this); }
00132
00133 virtual REAL operator() (const Input& a, const Input& b) const {
00134 return -std::sqrt(norm_2(a, b));
00135 }
00136
00137 virtual void set_data (const pDataSet& pd) {
00138 Kernel::set_data(pd);
00139 const UINT n = ptd->size();
00140 x_norm2.resize(n);
00141 for (UINT i = 0; i < n; ++i)
00142 x_norm2[i] = DOTPROD(ptd->x(i), ptd->x(i));
00143 }
00144 virtual REAL matrix (UINT i, UINT j) const {
00145 REAL n2 = x_norm2[i] + x_norm2[j] - 2*DOTPROD(ptd->x(i), ptd->x(j));
00146 return (n2 > 0)? -std::sqrt(n2) : 0;
00147 }
00148
00149 virtual void set_params (SVM_detail*) const;
00150 };
00151
00153 struct RBF : public Perceptron {
00154 REAL gamma;
00155 explicit RBF (REAL g = 0.5) : gamma(g) {}
00156 explicit RBF (std::istream& is) { is >> *this; }
00157
00158 virtual const id_t& id () const;
00159 virtual RBF* create () const { return new RBF(); }
00160 virtual RBF* clone () const { return new RBF(*this); }
00161
00162 virtual REAL operator() (const Input& a, const Input& b) const {
00163 return std::exp(-gamma * norm_2(a, b));
00164 }
00165
00166 virtual REAL matrix (UINT i, UINT j) const {
00167 REAL n2 = x_norm2[i] + x_norm2[j] - 2*DOTPROD(ptd->x(i), ptd->x(j));
00168 return std::exp(-gamma * n2);
00169 }
00170
00171 virtual void set_params (SVM_detail*) const;
00172
00173 protected:
00174 virtual bool serialize (std::ostream&, ver_list&) const;
00175 virtual bool unserialize (std::istream&, ver_list&, const id_t& = NIL_ID);
00176 };
00177
00179 struct Sigmoid : public Kernel {
00180 REAL gamma, coef0;
00181 Sigmoid (REAL g = 0.5, REAL c0 = 0) : gamma(g), coef0(c0) {};
00182 explicit Sigmoid (std::istream& is) { is >> *this; }
00183
00184 virtual const id_t& id () const;
00185 virtual Sigmoid* create () const { return new Sigmoid(); }
00186 virtual Sigmoid* clone () const { return new Sigmoid(*this); }
00187
00188 virtual REAL operator() (const Input& a, const Input& b) const {
00189 return std::tanh(gamma * DOTPROD(a, b) + coef0);
00190 }
00191 virtual void set_params (SVM_detail*) const;
00192
00193 protected:
00194 virtual bool serialize (std::ostream&, ver_list&) const;
00195 virtual bool unserialize (std::istream&, ver_list&, const id_t& = NIL_ID);
00196 };
00197
00198 }}
00199
00200 #ifdef __KERNEL_H__
00201 #warning "This header file may conflict with another `kernel.h' file."
00202 #endif
00203 #define __KERNEL_H__
00204 #endif