IT++ Logo
gmm.h
Go to the documentation of this file.
00001 
00029 #ifndef GMM_H
00030 #define GMM_H
00031 
00032 #include <itpp/base/mat.h>
00033 
00034 
00035 namespace itpp
00036 {
00037 
00039 
00045 class GMM
00046 {
00047 public:
00048   GMM();
00049   GMM(int nomix, int dim);
00050   GMM(std::string filename);
00051   void init_from_vq(const vec &codebook, int dim);
00052   // void init(const vec &w_in, const vec &m_in, const vec &sigma_in);
00053   void init(const vec &w_in, const mat &m_in, const mat &sigma_in);
00054   void load(std::string filename);
00055   void save(std::string filename);
00056   void set_weight(const vec &weights, bool compflag = true);
00057   void set_weight(int i, double weight, bool compflag = true);
00058   void set_mean(const mat &m_in);
00059   void set_mean(const vec &means, bool compflag = true);
00060   void set_mean(int i, const vec &means, bool compflag = true);
00061   void set_covariance(const mat &sigma_in);
00062   void set_covariance(const vec &covariances, bool compflag = true);
00063   void set_covariance(int i, const vec &covariances, bool compflag = true);
00064   int get_no_mixtures();
00065   int get_no_gaussians() const { return M; }
00066   int get_dimension();
00067   vec get_weight();
00068   double get_weight(int i);
00069   vec get_mean();
00070   vec get_mean(int i);
00071   vec get_covariance();
00072   vec get_covariance(int i);
00073   void marginalize(int d_new);
00074   void join(const GMM &newgmm);
00075   void clear();
00076   double likelihood(const vec &x);
00077   double likelihood_aposteriori(const vec &x, int mixture);
00078   vec likelihood_aposteriori(const vec &x);
00079   vec draw_sample();
00080 protected:
00081   vec   m, sigma, w;
00082   int   M, d;
00083 private:
00084   void  compute_internals();
00085   vec   normweight, normexp;
00086 };
00087 
00088 inline void GMM::set_weight(const vec &weights, bool compflag) {w = weights; if (compflag) compute_internals(); }
00089 inline void GMM::set_weight(int i, double weight, bool compflag) {w(i) = weight; if (compflag) compute_internals(); }
00090 inline void GMM::set_mean(const vec &means, bool compflag) {m = means; if (compflag) compute_internals(); }
00091 inline void GMM::set_covariance(const vec &covariances, bool compflag) {sigma = covariances; if (compflag) compute_internals(); }
00092 inline int GMM::get_dimension() {return d;}
00093 inline vec GMM::get_weight() {return w;}
00094 inline double GMM::get_weight(int i) {return w(i);}
00095 inline vec GMM::get_mean() {return m;}
00096 inline vec GMM::get_mean(int i) {return m.mid(i*d, d);}
00097 inline vec GMM::get_covariance() {return sigma;}
00098 inline vec GMM::get_covariance(int i) {return sigma.mid(i*d, d);}
00099 
00100 GMM gmmtrain(Array<vec> &TrainingData, int M, int NOITER = 30, bool VERBOSE = true);
00101 
00103 
00104 } // namespace itpp
00105 
00106 #endif // #ifndef GMM_H
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
SourceForge Logo

Generated on Sat Jul 9 2011 15:21:33 for IT++ by Doxygen 1.7.4