00001
00002 #ifndef __LEMGA_FEEDFORWARDNN_H__
00003 #define __LEMGA_FEEDFORWARDNN_H__
00004
00011 #include <vector>
00012 #include "learnmodel.h"
00013 #include "nnlayer.h"
00014
00015 namespace lemga {
00016
/// A LearnModel that chains NNLayer objects: forward() feeds the input to
/// the first layer and then feeds each layer's output to the next one.
///
/// Layers are stored in @a layer and addressed internally with indices
/// 1..n_layer (see forward() and operator[]); index 0 is never touched by
/// the code visible in this header.
/// NOTE(review): index 0 looks like a placeholder slot -- confirm against
/// the constructors and add_top()/add_bottom() in the implementation file.
00018 class FeedForwardNN : public LearnModel {
/// Declared here, defined elsewhere.
/// NOTE(review): presumably releases the layers held in @a layer -- confirm.
00019 void free_space ();
/// Propagate input @a x through every layer, in order.
///
/// The output of layer i is written to _y[i]; layer 1 reads @a x and each
/// later layer i reads _y[i-1]. The method is const, which is why _y is
/// declared mutable.
/// Preconditions (checked only by assert): at least one layer exists and
/// x.size() equals n_input().
00020 void forward (const Input& x) const {
00021 assert(n_layer > 0 && x.size() == n_input());
00022 layer[1]->feed_forward(x, _y[1]);
00023 for (UINT i = 2; i <= n_layer; ++i)
00024 layer[i]->feed_forward(_y[i-1], _y[i]);
00025 }
00026
00027 public:
/// A sequence of NNLayer::WVEC values.
/// NOTE(review): presumably one weight vector per layer -- confirm.
00028 typedef std::vector<NNLayer::WVEC> WEIGHT;
/// Identifiers for the training method selected via set_train_method().
/// What each method does is implemented outside this header.
00029 enum TRAIN_METHOD {
00030 GRADIENT_DESCENT,
00031 LINE_SEARCH,
00032 CONJUGATE_GRADIENT,
00033 WEIGHT_DECAY,
00034 ADAPTIVE_LEARNING_RATE
00035 };
00036
00037 protected:
/// Layer count: returned by size() and used as the loop bound in forward().
00038 UINT n_layer;
/// Pointers to the layers; forward() uses entries 1..n_layer.
00039 std::vector<NNLayer*> layer;
/// Per-layer outputs filled in by forward(): _y[i] is the output of layer[i].
00040 mutable std::vector<Output> _y;
/// NOTE(review): presumably per-layer derivatives matching _y -- not used
/// in this header; confirm in the implementation file.
00041 mutable std::vector<Output> _dy;
00042
/// Set by set_batch_mode(): true when batch mode is switched off.
00043 bool online_learn;
/// Set by set_train_method().
00044 TRAIN_METHOD train_method;
/// Set together by set_parameter().
/// NOTE(review): names suggest learning rate and minimal cost -- confirm.
00045 REAL learn_rate, min_cst;
/// Set by set_parameter().
/// NOTE(review): presumably the maximal number of training runs -- confirm.
00046 UINT max_run;
00047
00048 public:
00049 FeedForwardNN ();
00050 FeedForwardNN (const FeedForwardNN&);
/// Construct by reading the object from @a is with operator>>.
/// NOTE(review): there is no member-initializer list here, so any member
/// of built-in type is left uninitialized before the read -- confirm that
/// the read path (unserialize()) sets every member before using it.
00051 explicit FeedForwardNN (std::istream& is) { is >> *this; }
00052 virtual ~FeedForwardNN ();
/// Note: returns a reference to const.
00053 const FeedForwardNN& operator= (const FeedForwardNN&);
00054
00055 virtual const id_t& id () const;
/// @return a default-constructed FeedForwardNN allocated with new.
00056 virtual FeedForwardNN* create () const { return new FeedForwardNN(); }
/// @return a copy of this object (copy constructor) allocated with new.
00057 virtual FeedForwardNN* clone () const {
00058 return new FeedForwardNN(*this); }
00059
/// @return the number of layers (n_layer).
00060 UINT size () const { return n_layer; }
/// Access a layer with a zero-based index: @a n maps to layer[n+1].
/// No range check is performed.
00061 const NNLayer& operator[] (UINT n) const { return *layer[n+1]; }
/// Declared here, defined elsewhere.
/// NOTE(review): presumably appends a layer at the output end -- confirm.
00062 void add_top (const NNLayer&);
/// Declared here, defined elsewhere.
/// NOTE(review): presumably inserts a layer at the input end -- confirm.
00063 void add_bottom (const NNLayer&);
00064
/// Choose batch mode (@a b true, the default) or its opposite; the flag is
/// stored inverted in online_learn.
00065 void set_batch_mode (bool b = true) { online_learn = !b; }
/// Store the training method to use.
00066 void set_train_method (TRAIN_METHOD m) { train_method = m; }
/// Store the three training parameters without any validation.
/// @param lr     copied to learn_rate
/// @param mincst copied to min_cst
/// @param maxrun copied to max_run
00072 void set_parameter (REAL lr, REAL mincst, UINT maxrun) {
00073 learn_rate = lr; min_cst = mincst; max_run = maxrun; }
00074
/// Always reports true.
00075 virtual bool support_weighted_data () const { return true; }
00076 virtual void initialize ();
00077 virtual REAL train ();
00078 virtual Output operator() (const Input&) const;
00079
00080 protected:
00081 virtual bool serialize (std::ostream&, ver_list&) const;
00082 virtual bool unserialize (std::istream&, ver_list&,
00083 const id_t& = empty_id);
00084
/// Cost of output @a F against @a y; this default forwards to r_error(F, y).
/// Virtual, so derived classes may substitute another cost.
00085 virtual REAL _cost (const Output& F, const Output& y) const {
00086 return r_error(F, y); }
/// Declared here, defined elsewhere.
/// NOTE(review): presumably the derivative of _cost() with respect to
/// @a F -- confirm.
00087 virtual Output _cost_deriv (const Output& F, const Output& y) const;
00088 virtual void log_cost (UINT epoch, REAL err);
00089
00090 public:
/// Weight accessors; both are defined outside this header.
00091 WEIGHT weight () const;
00092 void set_weight (const WEIGHT&);
00093
/// Cost and gradient queries, each with and without an index argument.
/// NOTE(review): @a idx presumably selects a single training sample and
/// the argument-free overloads cover the whole set -- confirm.
00094 REAL cost (UINT idx) const;
00095 REAL cost () const;
00096 WEIGHT gradient (UINT idx) const;
00097 WEIGHT gradient () const;
00098 void clear_gradient () const;
00099
/// Declared here, defined elsewhere.
/// NOTE(review): presumably the stopping test for the optimization, based
/// on the step count and current cost -- confirm.
00100 bool stop_opt (UINT step, REAL cst);
00101 };
00102
00103 }
00104
00105 #ifdef __FEEDFORWARDNN_H__
00106 #warning "This header file may conflict with another `feedforwardnn.h' file."
00107 #endif
00108 #define __FEEDFORWARDNN_H__
00109 #endif