feedforwardnn.h

Go to the documentation of this file.
00001 // -*- C++ -*-
00002 #ifndef __LEMGA_FEEDFORWARDNN_H__
00003 #define __LEMGA_FEEDFORWARDNN_H__
00004 
00011 #include <vector>
00012 #include "learnmodel.h"
00013 #include "nnlayer.h"
00014 
00015 namespace lemga {
00016 
/// A LearnModel built from a sequence of NNLayer objects held by pointer.
///
/// Internally the layers (and their output buffers) are addressed with
/// 1-based indices: forward() uses layer[1] .. layer[n_layer]. The public
/// operator[] takes a 0-based index and maps it to layer[n+1].
/// Only the inline members are defined here; everything else is defined
/// outside this header.
00018 class FeedForwardNN : public LearnModel {
          /// Declared here, defined elsewhere.
          /// NOTE(review): presumably releases the layers referenced by
          /// `layer` -- confirm against the implementation file.
00019     void free_space ();
          /// Pushes the input @a x through layers 1 .. n_layer in order,
          /// storing the output of layer i in _y[i].
          /// Although const, it overwrites the mutable buffer _y.
          /// Requires at least one layer and x.size() == n_input() (checked
          /// only by assert, i.e. not in builds that define NDEBUG).
00020     void forward (const Input& x) const {
00021         assert(n_layer > 0 && x.size() == n_input());
              // The first layer reads the external input directly ...
00022         layer[1]->feed_forward(x, _y[1]);
              // ... every later layer reads the previous layer's output.
00023         for (UINT i = 2; i <= n_layer; ++i)
00024             layer[i]->feed_forward(_y[i-1], _y[i]);
00025     }
00026 
00027 public:
          /// Container of NNLayer::WVEC values.
          /// NOTE(review): presumably one weight vector per layer -- confirm.
00028     typedef std::vector<NNLayer::WVEC> WEIGHT;
          /// Identifiers accepted by set_train_method().
          /// NOTE(review): how each value changes training is decided outside
          /// this header (presumably in train()) -- confirm there.
00029     enum TRAIN_METHOD {
00030         GRADIENT_DESCENT,
00031         LINE_SEARCH,
00032         CONJUGATE_GRADIENT,
00033         WEIGHT_DECAY,
00034         ADAPTIVE_LEARNING_RATE
00035     };
00036 
00037 protected:
          /// Number of layers; this is the value returned by size().
00038     UINT n_layer;                   
          /// Layer pointers; forward() uses indices 1 .. n_layer.
00039     std::vector<NNLayer*> layer;    
          /// _y[i] receives the output of layer i in forward(); mutable so
          /// that const members can write it.
00040     mutable std::vector<Output> _y; 
          /// NOTE(review): presumably per-layer derivative buffers matching
          /// _y -- not used in this header, confirm in the implementation.
00041     mutable std::vector<Output> _dy;
00042 
          /// Set to the negation of the flag passed to set_batch_mode().
00043     bool online_learn;
          /// Value last given to set_train_method().
00044     TRAIN_METHOD train_method;
          /// Values stored by set_parameter().
          /// NOTE(review): presumably the learning rate and a minimal-cost
          /// stopping threshold -- confirm where they are consumed.
00045     REAL learn_rate, min_cst;
          /// Value stored by set_parameter().
          /// NOTE(review): presumably a cap on training iterations -- confirm.
00046     UINT max_run;
00047 
00048 public:
00049     FeedForwardNN ();
00050     FeedForwardNN (const FeedForwardNN&);
          /// Constructs the object by reading it from @a is with operator>>.
00051     explicit FeedForwardNN (std::istream& is) { is >> *this; }
00052     virtual ~FeedForwardNN ();
          /// Copy assignment. Note that it returns a reference to const, so
          /// the result of an assignment expression cannot be modified.
00053     const FeedForwardNN& operator= (const FeedForwardNN&);
00054 
00055     virtual const id_t& id () const;
          /// Returns a pointer to a new, default-constructed FeedForwardNN
          /// allocated with `new`.
00056     virtual FeedForwardNN* create () const { return new FeedForwardNN(); }
          /// Returns a pointer to a copy of this object (made with the copy
          /// constructor) allocated with `new`.
00057     virtual FeedForwardNN* clone () const {
00058         return new FeedForwardNN(*this); }
00059 
          /// Returns the number of layers (n_layer).
00060     UINT size () const { return n_layer; }
          /// Read-only access to a layer by 0-based index: @a n maps to the
          /// internal slot layer[n+1]. No range check is performed here.
00061     const NNLayer& operator[] (UINT n) const { return *layer[n+1]; }
00062     void add_top (const NNLayer&);
00063     void add_bottom (const NNLayer&);
00064 
          /// Selects batch mode when @a b is true (the default); stores the
          /// opposite value in online_learn.
00065     void set_batch_mode (bool b = true) { online_learn = !b; }
          /// Stores @a m in train_method.
00066     void set_train_method (TRAIN_METHOD m) { train_method = m; }
          /// Stores @a lr, @a mincst and @a maxrun in learn_rate, min_cst and
          /// max_run respectively; no validation is done here.
00072     void set_parameter (REAL lr, REAL mincst, UINT maxrun) {
00073         learn_rate = lr; min_cst = mincst; max_run = maxrun; }
00074 
          /// Always returns true.
00075     virtual bool support_weighted_data () const { return true; }
00076     virtual void initialize ();
00077     virtual REAL train ();
00078     virtual Output operator() (const Input&) const;
00079 
00080 protected:
00081     virtual bool serialize (std::ostream&, ver_list&) const;
00082     virtual bool unserialize (std::istream&, ver_list&,
00083                               const id_t& = empty_id);
00084 
          /// Returns r_error(F, y). Virtual, so subclasses may replace it.
00085     virtual REAL _cost (const Output& F, const Output& y) const {
00086         return r_error(F, y); }
          /// NOTE(review): presumably the derivative of _cost() with respect
          /// to @a F -- defined elsewhere, confirm.
00087     virtual Output _cost_deriv (const Output& F, const Output& y) const;
00088     virtual void log_cost (UINT epoch, REAL err);
00089 
00090 public:
00091     WEIGHT weight () const;
00092     void set_weight (const WEIGHT&);
00093 
          // NOTE(review): the (UINT idx) overloads presumably work on a single
          // training sample and the parameterless ones on the whole data set
          // -- the definitions are not in this header, confirm there.
00094     REAL cost (UINT idx) const;
00095     REAL cost () const;
00096     WEIGHT gradient (UINT idx) const;
00097     WEIGHT gradient () const;
00098     void clear_gradient () const;
00099 
          /// NOTE(review): presumably the stopping test for the optimization
          /// loop, based on the step count and current cost -- confirm.
00100     bool stop_opt (UINT step, REAL cst);
00101 };
00102 
00103 } // namespace lemga
00104 
// Name-clash detection: if the un-prefixed macro is already defined when
// this header is first processed, emit a warning; then define it here.
// NOTE(review): presumably other `feedforwardnn.h' files follow the same
// convention so that the clash is reported whichever is included first.
00105 #ifdef  __FEEDFORWARDNN_H__
00106 #warning "This header file may conflict with another `feedforwardnn.h' file."
00107 #endif
00108 #define __FEEDFORWARDNN_H__
00109 #endif

Generated on Mon Jan 9 23:43:24 2006 for LEMGA by  doxygen 1.4.6