nnlayer.h

Go to the documentation of this file.
00001 // -*- C++ -*-
00002 #ifndef __LEMGA_NNLAYER_H__
00003 #define __LEMGA_NNLAYER_H__
00004 
00011 #include <vector>
00012 #include "learnmodel.h"
00013 
00014 namespace lemga {
00015
00053 class NNLayer : public LearnModel {  // One neural-network layer: n_in inputs feeding n_unit units (see constructor); size() == n_output().
00054     mutable REAL stored_sigmoid;  // NOTE(review): mutable scratch value not referenced anywhere in this header; confirm its role in the implementation file.
00055
00056 public:
00057     typedef std::vector<REAL> WVEC; ///< Weight-vector type (used for w, dw, weight() and gradient()).
00058     typedef std::vector<REAL> DVEC; ///< Derivative/delta vector type (used for sig_der and the back_propagate() arguments).
00059
00060 protected:
00061     REAL w_min, w_max;  ///< Weight range set by set_weight_range(); presumably used by initialize() — confirm in .cpp.
00062     WVEC w;   ///< Layer weights; read via weight(), set via set_weight().
00063     WVEC dw;  ///< Gradient w.r.t. w, read via gradient(); presumably accumulated by back_propagate() and zeroed by clear_gradient().
00064     mutable DVEC sig_der;  // f'(wsum): sigmoid derivative at the weighted sums; mutable, presumably so the const feed_forward() can record it for back_propagate().
00065
00066 public:
00067     explicit NNLayer (UINT n_in = 0, UINT n_unit = 0);  // n_in inputs, n_unit units; the (0, 0) defaults give an empty layer.
00068     explicit NNLayer (std::istream& is) { is >> *this; }  // Reads a serialized layer from is. NOTE(review): members are not initialized before the read; relies on unserialize() setting all of them.
00069
00070     virtual const id_t& id () const;
00071     virtual NNLayer* create () const { return new NNLayer(); }  // New default layer, heap-allocated with new; caller must delete it.
00072     virtual NNLayer* clone () const { return new NNLayer(*this); }  // Heap-allocated copy of *this; caller must delete it.
00073
00074     UINT size () const { return n_output(); }  // Same as n_output().
00075
00076     void set_weight_range (REAL min, REAL max) {  // Requires min < max (checked with assert only).
00077         assert(min < max);
00078         w_min = min; w_max = max;
00079     }
00080     const WVEC& weight () const { return w; }  // Current weights (no copy).
00081     void set_weight (const WVEC&);  // Replaces the weights; expected size/layout is defined in the implementation.
00082
00083     const WVEC& gradient () const { return dw; }  // Current gradient dw (no copy).
00084     void clear_gradient ();  // Presumably resets dw to zero — confirm in .cpp.
00085
00086     virtual void initialize ();  // Presumably (re)initializes w within [w_min, w_max] — confirm in .cpp.
00087     virtual REAL train () { OBJ_FUNC_UNDEFINED("train"); }  // A layer has no train() of its own; the call is reported via OBJ_FUNC_UNDEFINED (defined elsewhere).
00088     virtual Output operator() (const Input& x) const {  // Allocates an n_output()-sized result and fills it via feed_forward().
00089         Output y(n_output());
00090         feed_forward(x, y);
00091         return y;
00092     }
00093
00094     void feed_forward (const Input&, Output&) const;  // Computes the layer output for the given input; operator() passes an output already sized to n_output().
00095     void back_propagate (const Input&, const DVEC&, DVEC&);  // Non-const: may update dw. NOTE(review): argument roles (input, incoming deltas, outgoing deltas) assumed — confirm in .cpp.
00096
00097 protected:
00098     virtual REAL sigmoid (REAL) const;  // Activation function f; virtual so subclasses can replace it.
00099     virtual REAL sigmoid_deriv (REAL) const;  // Derivative f' matching sigmoid(); virtual alongside it.
00100     virtual bool serialize (std::ostream&, ver_list&) const;
00101     virtual bool unserialize (std::istream&, ver_list&,
00102                               const id_t& = empty_id);
00103 };
00104
00105 } // namespace lemga
00106 
00107 #ifdef  __NNLAYER_H__  // Some other header already claimed the generic guard name for an `nnlayer.h'.
00108 #warning "This header file may conflict with another `nnlayer.h' file."
00109 #endif  // NOTE(review): #warning is a compiler extension before C++23 (e.g. MSVC rejects it when this branch is taken).
00110 #define __NNLAYER_H__  // Claim the generic name too, so a later, different nnlayer.h can detect the clash.
00111 #endif  // closes the __LEMGA_NNLAYER_H__ include guard. NOTE(review): identifiers containing "__" are reserved for the implementation in C++.

Generated on Mon Jan 9 23:43:24 2006 for LEMGA by doxygen 1.4.6