00001 00029 #ifndef MOG_GENERIC_H 00030 #define MOG_GENERIC_H 00031 00032 #include <itpp/base/vec.h> 00033 #include <itpp/base/mat.h> 00034 #include <itpp/base/array.h> 00035 00036 00037 namespace itpp 00038 { 00039 00056 class MOG_generic 00057 { 00058 00059 public: 00060 00066 MOG_generic() { init(); } 00067 00071 MOG_generic(const std::string &name_in) { load(name_in); } 00072 00078 MOG_generic(const int &K_in, const int &D_in, bool full_in = false) { init(K_in, D_in, full_in); } 00079 00087 MOG_generic(Array<vec> &means_in, bool full_in = false) { init(means_in, full_in); } 00088 00095 MOG_generic(Array<vec> &means_in, Array<vec> &diag_covs_in, vec &weights_in) { init(means_in, diag_covs_in, weights_in); } 00096 00103 MOG_generic(Array<vec> &means_in, Array<mat> &full_covs_in, vec &weights_in) { init(means_in, full_covs_in, weights_in); } 00104 00106 virtual ~MOG_generic() { cleanup(); } 00107 00112 void init(); 00113 00119 void init(const int &K_in, const int &D_in, bool full_in = false); 00120 00128 void init(Array<vec> &means_in, bool full_in = false); 00129 00136 void init(Array<vec> &means_in, Array<vec> &diag_covs_in, vec &weights_in); 00137 00144 void init(Array<vec> &means_in, Array<mat> &full_covs_in, vec &weights_in); 00145 00150 virtual void cleanup(); 00151 00153 bool is_valid() const { return valid; } 00154 00156 bool is_full() const { return full; } 00157 00159 int get_K() const { if (valid) return(K); else return(0); } 00160 00162 int get_D() const { if (valid) return(D); else return(0); } 00163 00165 vec get_weights() const { vec tmp; if (valid) { tmp = weights; } return tmp; } 00166 00168 Array<vec> get_means() const { Array<vec> tmp; if (valid) { tmp = means; } return tmp; } 00169 00171 Array<vec> get_diag_covs() const { Array<vec> tmp; if (valid && !full) { tmp = diag_covs; } return tmp; } 00172 00174 Array<mat> get_full_covs() const { Array<mat> tmp; if (valid && full) { tmp = full_covs; } return tmp; } 00175 00179 void set_means(Array<vec> 
&means_in); 00180 00184 void set_diag_covs(Array<vec> &diag_covs_in); 00185 00189 void set_full_covs(Array<mat> &full_covs_in); 00190 00194 void set_weights(vec &weights_in); 00195 00197 void set_means_zero(); 00198 00200 void set_diag_covs_unity(); 00201 00203 void set_full_covs_unity(); 00204 00206 void set_weights_uniform(); 00207 00213 void set_checks(bool do_checks_in) { do_checks = do_checks_in; } 00214 00218 void set_paranoid(bool paranoid_in) { paranoid = paranoid_in; } 00219 00223 virtual void load(const std::string &name_in); 00224 00228 virtual void save(const std::string &name_in) const; 00229 00246 virtual void join(const MOG_generic &B_in); 00247 00255 virtual void convert_to_diag(); 00256 00262 virtual void convert_to_full(); 00263 00265 virtual double log_lhood_single_gaus(const vec &x_in, const int k); 00266 00268 virtual double log_lhood(const vec &x_in); 00269 00271 virtual double lhood(const vec &x_in); 00272 00274 virtual double avg_log_lhood(const Array<vec> &X_in); 00275 00276 protected: 00277 00279 bool do_checks; 00280 00282 bool valid; 00283 00285 bool full; 00286 00288 bool paranoid; 00289 00291 int K; 00292 00294 int D; 00295 00297 Array<vec> means; 00298 00300 Array<vec> diag_covs; 00301 00303 Array<mat> full_covs; 00304 00306 vec weights; 00307 00309 double log_max_K; 00310 00316 vec log_det_etc; 00317 00319 vec log_weights; 00320 00322 Array<mat> full_covs_inv; 00323 00325 Array<vec> diag_covs_inv_etc; 00326 00328 bool check_size(const vec &x_in) const; 00329 00331 bool check_size(const Array<vec> &X_in) const; 00332 00334 bool check_array_uniformity(const Array<vec> & A) const; 00335 00337 void set_means_internal(Array<vec> &means_in); 00339 void set_diag_covs_internal(Array<vec> &diag_covs_in); 00341 void set_full_covs_internal(Array<mat> &full_covs_in); 00343 void set_weights_internal(vec &_weigths); 00344 00346 void set_means_zero_internal(); 00348 void set_diag_covs_unity_internal(); 00350 void set_full_covs_unity_internal(); 
00352 void set_weights_uniform_internal(); 00353 00355 void convert_to_diag_internal(); 00357 void convert_to_full_internal(); 00358 00360 virtual void setup_means(); 00361 00363 virtual void setup_covs(); 00364 00366 virtual void setup_weights(); 00367 00369 virtual void setup_misc(); 00370 00372 virtual double log_lhood_single_gaus_internal(const vec &x_in, const int k); 00374 virtual double log_lhood_internal(const vec &x_in); 00376 virtual double lhood_internal(const vec &x_in); 00377 00378 private: 00379 vec tmpvecD; 00380 vec tmpvecK; 00381 00382 }; 00383 00384 } // namespace itpp 00385 00386 #endif // #ifndef MOG_GENERIC_H
// Generated on Sat Jul 9 2011 15:21:33 for IT++ by Doxygen 1.7.4