00001
00002 #ifndef __LEMGA_STUMP_H__
00003 #define __LEMGA_STUMP_H__
00004
00011 #include "learnmodel.h"
00012
00013 namespace lemga {
00014
00018 class Stump : public LearnModel {
00019 UINT idx;
00020 REAL bd1, bd2;
00021 bool dir;
00022 bool hard;
00023
00024 public:
00025 explicit Stump (UINT n_in = 0)
00026 : LearnModel(n_in, 1), idx(0), bd1(0), bd2(0), hard(true) {}
00027 explicit Stump (std::istream& is) { is >> *this; }
00028
00029 virtual const id_t& id () const;
00030 virtual Stump* create () const { return new Stump(); }
00031 virtual Stump* clone () const { return new Stump(*this); }
00032
00033 UINT index () const { return idx; }
00034 REAL threshold () const { return (bd1+bd2)/2; }
00035 bool direction () const { return dir; }
00036 bool soft_threshold () const { return !hard; }
00037 void use_soft_threshold (bool s = true) { hard = !s; }
00038
00039 virtual bool support_weighted_data () const { return true; }
00040
00041 virtual REAL train ();
00043 static REAL train_1d (const std::vector<REAL>&, const std::vector<REAL>&,
00044 REAL, bool&, bool&, REAL&, REAL&);
00046 static REAL train_1d (const std::vector<REAL>&, const std::vector<REAL>&);
00047
00048 virtual Output operator() (const Input&) const;
00049
00050 protected:
00051 virtual bool serialize (std::ostream&, ver_list&) const;
00052 virtual bool unserialize (std::istream&, ver_list&,
00053 const id_t& = empty_id);
00054 };
00055
00056 }
00057
// Collision check: another header named stump.h may also define the
// generic __STUMP_H__ marker. If it was already seen, warn at compile time;
// in any case, claim the marker for this header.
#if defined(__STUMP_H__)
#  warning "This header file may conflict with another `stump.h' file."
#endif
#define __STUMP_H__
00062 #endif