kernel.h

Go to the documentation of this file.
00001 // -*- C++ -*-
00002 #ifndef __LEMGA_KERNEL_H__
00003 #define __LEMGA_KERNEL_H__
00004 
00011 #include <cmath>
00012 #include <numeric>
00013 #include "learnmodel.h"
00014 
// Dot product of two containers x and y, accumulated as double.
// Iterates over x; y must have at least as many elements as x.
// The arguments are parenthesized so that expressions such as DOTPROD(*p, *q)
// expand correctly. Note that x is still evaluated twice.
#define DOTPROD(x,y) std::inner_product((x).begin(), (x).end(), (y).begin(), .0)
00016 
00017 namespace lemga {
00018 
00019 // only for LIBSVM, see Kernel::set_params.
00020 struct SVM_detail;
00021 
00022 namespace kernel {
00023 
00024 inline REAL norm_1 (const Input& u, const Input& v) {
00025     REAL sum(0);
00026     Input::const_iterator x = u.begin(), y = v.begin();
00027     for (; x != u.end(); ++x, ++y)
00028         sum += std::fabs(*x - *y);
00029     return sum;
00030 }
00031 
00032 inline REAL norm_2 (const Input& u, const Input& v) {
00033     REAL sum(0);
00034     Input::const_iterator x = u.begin(), y = v.begin();
00035     for (; x != u.end(); ++x, ++y) {
00036         REAL d = *x - *y;
00037         sum += d * d;
00038     }
00039     return sum;
00040 }
00041 
00043 class Kernel : public Object {
00044 public:
00045     virtual Kernel* create () const = 0;
00046     virtual Kernel* clone () const = 0;
00047 
00049     virtual REAL operator() (const Input&, const Input&) const = 0;
00051     virtual void set_data (const pDataSet& pd) { ptd = pd; }
00053     virtual REAL matrix (UINT i, UINT j) const
00054     { return operator()(ptd->x(i), ptd->x(j)); }
00055 
00059     virtual void set_params (SVM_detail*) const = 0;
00060 
00061 protected:
00062     pDataSet ptd;
00063     virtual bool serialize (std::ostream&, ver_list&) const;
00064     virtual bool unserialize (std::istream&, ver_list&, const id_t& = NIL_ID);
00065 };
00066 
00068 struct Linear : public Kernel {
00069     Linear () {}
00070     explicit Linear (std::istream& is) { is >> *this; }
00071 
00072     virtual const id_t& id () const;
00073     virtual Linear* create () const { return new Linear(); }
00074     virtual Linear* clone () const { return new Linear(*this); }
00075 
00076     virtual REAL operator() (const Input& a, const Input& b) const {
00077         return DOTPROD(a, b);
00078     }
00079     virtual void set_params (SVM_detail*) const;
00080 };
00081 
00083 struct Polynomial : public Kernel {
00084     UINT degree;
00085     REAL gamma, coef0;
00086 
00087     Polynomial (UINT d = 3, REAL g = 0.5, REAL c0 = 0)
00088         : degree(d), gamma(g), coef0(c0) {};
00089     explicit Polynomial (std::istream& is) { is >> *this; }
00090 
00091     virtual const id_t& id () const;
00092     virtual Polynomial* create () const { return new Polynomial(); }
00093     virtual Polynomial* clone () const { return new Polynomial(*this); }
00094 
00095     virtual REAL operator() (const Input& a, const Input& b) const {
00096         return std::pow(gamma * DOTPROD(a, b) + coef0, (double) degree);
00097     }
00098     virtual void set_params (SVM_detail*) const;
00099 
00100 protected:
00101     virtual bool serialize (std::ostream&, ver_list&) const;
00102     virtual bool unserialize (std::istream&, ver_list&, const id_t& = NIL_ID);
00103 };
00104 
00106 struct Stump : public Kernel {
00107     Stump () {}
00108     explicit Stump (std::istream& is) { is >> *this; }
00109 
00110     virtual const id_t& id () const;
00111     virtual Stump* create () const { return new Stump(); }
00112     virtual Stump* clone () const { return new Stump(*this); }
00113 
00114     virtual REAL operator() (const Input& a, const Input& b) const {
00115         return -norm_1(a, b);
00116     }
00117     virtual void set_params (SVM_detail*) const;
00118 };
00119 
00121 struct Perceptron : public Kernel {
00122 protected:
00123     std::vector<REAL> x_norm2; 
00124 
00125 public:
00126     Perceptron () {}
00127     explicit Perceptron (std::istream& is) { is >> *this; }
00128 
00129     virtual const id_t& id () const;
00130     virtual Perceptron* create () const { return new Perceptron(); }
00131     virtual Perceptron* clone () const { return new Perceptron(*this); }
00132 
00133     virtual REAL operator() (const Input& a, const Input& b) const {
00134         return -std::sqrt(norm_2(a, b));
00135     }
00136 
00137     virtual void set_data (const pDataSet& pd) {
00138         Kernel::set_data(pd);
00139         const UINT n = ptd->size();
00140         x_norm2.resize(n);
00141         for (UINT i = 0; i < n; ++i)
00142             x_norm2[i] = DOTPROD(ptd->x(i), ptd->x(i));
00143     }
00144     virtual REAL matrix (UINT i, UINT j) const {
00145         REAL n2 = x_norm2[i] + x_norm2[j] - 2*DOTPROD(ptd->x(i), ptd->x(j));
00146         return (n2 > 0)? -std::sqrt(n2) : 0;   // avoid -0.0
00147     }
00148 
00149     virtual void set_params (SVM_detail*) const;
00150 };
00151 
00153 struct RBF : public Perceptron {
00154     REAL gamma;
00155     explicit RBF (REAL g = 0.5) : gamma(g) {}
00156     explicit RBF (std::istream& is) { is >> *this; }
00157 
00158     virtual const id_t& id () const;
00159     virtual RBF* create () const { return new RBF(); }
00160     virtual RBF* clone () const { return new RBF(*this); }
00161 
00162     virtual REAL operator() (const Input& a, const Input& b) const {
00163         return std::exp(-gamma * norm_2(a, b));
00164     }
00165 
00166     virtual REAL matrix (UINT i, UINT j) const {
00167         REAL n2 = x_norm2[i] + x_norm2[j] - 2*DOTPROD(ptd->x(i), ptd->x(j));
00168         return std::exp(-gamma * n2);
00169     }
00170 
00171     virtual void set_params (SVM_detail*) const;
00172 
00173 protected:
00174     virtual bool serialize (std::ostream&, ver_list&) const;
00175     virtual bool unserialize (std::istream&, ver_list&, const id_t& = NIL_ID);
00176 };
00177 
00179 struct Sigmoid : public Kernel {
00180     REAL gamma, coef0;
00181     Sigmoid (REAL g = 0.5, REAL c0 = 0) : gamma(g), coef0(c0) {};
00182     explicit Sigmoid (std::istream& is) { is >> *this; }
00183 
00184     virtual const id_t& id () const;
00185     virtual Sigmoid* create () const { return new Sigmoid(); }
00186     virtual Sigmoid* clone () const { return new Sigmoid(*this); }
00187 
00188     virtual REAL operator() (const Input& a, const Input& b) const {
00189         return std::tanh(gamma * DOTPROD(a, b) + coef0);
00190     }
00191     virtual void set_params (SVM_detail*) const;
00192 
00193 protected:
00194     virtual bool serialize (std::ostream&, ver_list&) const;
00195     virtual bool unserialize (std::istream&, ver_list&, const id_t& = NIL_ID);
00196 };
00197 
00198 }} // namespace lemga::kernel
00199 
00200 #ifdef  __KERNEL_H__
00201 #warning "This header file may conflict with another `kernel.h' file."
00202 #endif
00203 #define __KERNEL_H__
00204 #endif

Generated on Wed Nov 8 08:15:21 2006 for LEMGA by doxygen 1.4.6