kernel.h

Go to the documentation of this file.
00001 // -*- C++ -*-
00002 #ifndef __LEMGA_KERNEL_H__
00003 #define __LEMGA_KERNEL_H__
00004 
#include <cassert>
#include <cmath>
#include <numeric>
#include <vector>
#include "learnmodel.h"
00014 
// Inner product <x, y> of two containers, accumulated in double starting
// from 0.0 (std::inner_product over x's range and y from its beginning).
// Each argument is parenthesised so expressions such as DOTPROD(*p, q)
// expand correctly. Note: x is still evaluated twice and y once, so do
// not pass arguments with side effects.
#define DOTPROD(x,y) std::inner_product((x).begin(), (x).end(), (y).begin(), .0)
00016 
00017 namespace lemga {
00018 
00019 // only for LIBSVM, see Kernel::set_params.
00020 struct SVM_detail;
00021 
00022 namespace kernel {
00023 
00024 inline REAL norm_1 (const Input& u, const Input& v) {
00025     REAL sum(0);
00026     Input::const_iterator x = u.begin(), y = v.begin();
00027     for (; x != u.end(); ++x, ++y)
00028         sum += std::fabs(*x - *y);
00029     return sum;
00030 }
00031 
00032 inline REAL norm_2 (const Input& u, const Input& v) {
00033     REAL sum(0);
00034     Input::const_iterator x = u.begin(), y = v.begin();
00035     for (; x != u.end(); ++x, ++y) {
00036         REAL d = *x - *y;
00037         sum += d * d;
00038     }
00039     return sum;
00040 }
00041 
00043 class Kernel {
00044 protected:
00045     pDataSet ptd;
00046 
00047 public:
00048     virtual ~Kernel () { /* get rid of GCC warnings */ }
00049 
00051     virtual REAL operator() (const Input&, const Input&) const = 0;
00053     virtual void set_data (const pDataSet& pd) { ptd = pd; }
00055     virtual REAL matrix (UINT i, UINT j) const
00056     { return operator()(ptd->x(i), ptd->x(j)); }
00057 
00061     virtual void set_params (SVM_detail*) const = 0;
00062 };
00063 
00064 struct Linear : public Kernel {
00065     virtual REAL operator() (const Input& a, const Input& b) const {
00066         return DOTPROD(a, b);
00067     }
00068     virtual void set_params (SVM_detail*) const;
00069 };
00070 
00071 struct Polynomial : public Kernel {
00072     UINT degree;
00073     REAL gamma, coef0;
00074     Polynomial (UINT d, REAL g, REAL c0)
00075         : degree(d), gamma(g), coef0(c0) {};
00076 
00077     virtual REAL operator() (const Input& a, const Input& b) const {
00078         return std::pow(gamma * DOTPROD(a, b) + coef0, (REAL) degree);
00079     }
00080     virtual void set_params (SVM_detail*) const;
00081 };
00082 
00083 struct Stump : public Kernel {
00084     virtual REAL operator() (const Input& a, const Input& b) const {
00085         return -norm_1(a, b);
00086     }
00087     virtual void set_params (SVM_detail*) const;
00088 };
00089 
00090 struct Perceptron : public Kernel {
00091 protected:
00092     std::vector<REAL> x_norm2; 
00093 
00094 public:
00095     virtual REAL operator() (const Input& a, const Input& b) const {
00096         return -std::sqrt(norm_2(a, b));
00097     }
00098 
00099     virtual void set_data (const pDataSet& pd) {
00100         Kernel::set_data(pd);
00101         const UINT n = ptd->size();
00102         x_norm2.resize(n);
00103         for (UINT i = 0; i < n; ++i)
00104             x_norm2[i] = DOTPROD(ptd->x(i), ptd->x(i));
00105     }
00106     virtual REAL matrix (UINT i, UINT j) const {
00107         REAL n2 = x_norm2[i] + x_norm2[j] - 2*DOTPROD(ptd->x(i), ptd->x(j));
00108         return (n2 > 0)? -std::sqrt(n2) : 0;   // avoid -0.0
00109     }
00110 
00111     virtual void set_params (SVM_detail*) const;
00112 };
00113 
00114 struct RBF : public Perceptron {
00115     REAL gamma;
00116     explicit RBF (REAL g) : gamma(g) {}
00117 
00118     virtual REAL operator() (const Input& a, const Input& b) const {
00119         return std::exp(-gamma * norm_2(a, b));
00120     }
00121 
00122     virtual REAL matrix (UINT i, UINT j) const {
00123         REAL n2 = x_norm2[i] + x_norm2[j] - 2*DOTPROD(ptd->x(i), ptd->x(j));
00124         return std::exp(-gamma * n2);
00125     }
00126 
00127     virtual void set_params (SVM_detail*) const;
00128 };
00129 
00130 struct Sigmoid : public Kernel {
00131     REAL gamma, coef0;
00132     Sigmoid (REAL g, REAL c0) : gamma(g), coef0(c0) {};
00133 
00134     virtual REAL operator() (const Input& a, const Input& b) const {
00135         return std::tanh(gamma * DOTPROD(a, b) + coef0);
00136     }
00137     virtual void set_params (SVM_detail*) const;
00138 };
00139 
00140 }} // namespace lemga::kernel
00141 
00142 #ifdef  __KERNEL_H__
00143 #warning "This header file may conflict with another `kernel.h' file."
00144 #endif
00145 #define __KERNEL_H__
00146 #endif

Generated on Mon Jan 9 23:43:24 2006 for LEMGA by doxygen 1.4.6