IT++ Logo
mog_diag_em.cpp
Go to the documentation of this file.
00001 
00030 #include <itpp/stat/mog_diag_em.h>
00031 #include <itpp/base/math/log_exp.h>
00032 #include <itpp/base/timing.h>
00033 
00034 #include <iostream>
00035 #include <iomanip>
00036 
00037 namespace itpp
00038 {
00039 
00041 void inline MOG_diag_EM_sup::update_internals()
00042 {
00043 
00044   double Ddiv2_log_2pi = D / 2.0 * std::log(m_2pi);
00045 
00046   for (int k = 0;k < K;k++)  c_log_weights[k] = std::log(c_weights[k]);
00047 
00048   for (int k = 0;k < K;k++) {
00049     double acc = 0.0;
00050     double * c_diag_cov = c_diag_covs[k];
00051     double * c_diag_cov_inv_etc = c_diag_covs_inv_etc[k];
00052 
00053     for (int d = 0;d < D;d++) {
00054       double tmp = c_diag_cov[d];
00055       c_diag_cov_inv_etc[d] = 1.0 / (2.0 * tmp);
00056       acc += std::log(tmp);
00057     }
00058 
00059     c_log_det_etc[k] = -Ddiv2_log_2pi - 0.5 * acc;
00060   }
00061 
00062 }
00063 
00064 
00066 void inline MOG_diag_EM_sup::sanitise_params()
00067 {
00068 
00069   double acc = 0.0;
00070   for (int k = 0;k < K;k++) {
00071     if (c_weights[k] < weight_floor)  c_weights[k] = weight_floor;
00072     if (c_weights[k] > 1.0)  c_weights[k] = 1.0;
00073     acc += c_weights[k];
00074   }
00075   for (int k = 0;k < K;k++)  c_weights[k] /= acc;
00076 
00077   for (int k = 0;k < K;k++)
00078     for (int d = 0;d < D;d++)
00079       if (c_diag_covs[k][d] < var_floor)  c_diag_covs[k][d] = var_floor;
00080 
00081 }
00082 
00084 double MOG_diag_EM_sup::ml_update_params()
00085 {
00086 
00087   double acc_loglhood = 0.0;
00088 
00089   for (int k = 0;k < K;k++)  {
00090     c_acc_loglhood_K[k] = 0.0;
00091 
00092     double * c_acc_mean = c_acc_means[k];
00093     double * c_acc_cov  = c_acc_covs[k];
00094 
00095     for (int d = 0;d < D;d++) { c_acc_mean[d] = 0.0; c_acc_cov[d] = 0.0; }
00096   }
00097 
00098   for (int n = 0;n < N;n++) {
00099     double * c_x =  c_X[n];
00100 
00101     bool danger = paranoid;
00102     for (int k = 0;k < K;k++)  {
00103       double tmp = c_log_weights[k] + MOG_diag::log_lhood_single_gaus_internal(c_x, k);
00104       c_tmpvecK[k] = tmp;
00105       if (tmp >= log_max_K)  danger = true;
00106     }
00107 
00108     if (danger) {
00109 
00110       double log_sum = c_tmpvecK[0];
00111       for (int k = 1;k < K;k++)  log_sum = log_add(log_sum, c_tmpvecK[k]);
00112       acc_loglhood += log_sum;
00113 
00114       for (int k = 0;k < K;k++) {
00115 
00116         double * c_acc_mean = c_acc_means[k];
00117         double * c_acc_cov = c_acc_covs[k];
00118 
00119         double tmp_k = trunc_exp(c_tmpvecK[k] - log_sum);
00120         acc_loglhood_K[k] += tmp_k;
00121 
00122         for (int d = 0;d < D;d++) {
00123           double tmp_x = c_x[d];
00124           c_acc_mean[d] +=  tmp_k * tmp_x;
00125           c_acc_cov[d] += tmp_k * tmp_x * tmp_x;
00126         }
00127       }
00128     }
00129     else {
00130 
00131       double sum = 0.0;
00132       for (int k = 0;k < K;k++) { double tmp = std::exp(c_tmpvecK[k]); c_tmpvecK[k] = tmp; sum += tmp; }
00133       acc_loglhood += std::log(sum);
00134 
00135       for (int k = 0;k < K;k++) {
00136 
00137         double * c_acc_mean = c_acc_means[k];
00138         double * c_acc_cov = c_acc_covs[k];
00139 
00140         double tmp_k = c_tmpvecK[k] / sum;
00141         c_acc_loglhood_K[k] += tmp_k;
00142 
00143         for (int d = 0;d < D;d++) {
00144           double tmp_x = c_x[d];
00145           c_acc_mean[d] +=  tmp_k * tmp_x;
00146           c_acc_cov[d] += tmp_k * tmp_x * tmp_x;
00147         }
00148       }
00149     }
00150   }
00151 
00152   for (int k = 0;k < K;k++) {
00153 
00154     double * c_mean = c_means[k];
00155     double * c_diag_cov = c_diag_covs[k];
00156 
00157     double * c_acc_mean = c_acc_means[k];
00158     double * c_acc_cov = c_acc_covs[k];
00159 
00160     double tmp_k = c_acc_loglhood_K[k];
00161 
00162     c_weights[k] = tmp_k / N;
00163 
00164     for (int d = 0;d < D;d++) {
00165       double tmp_mean = c_acc_mean[d] / tmp_k;
00166       c_mean[d] = tmp_mean;
00167       c_diag_cov[d] = c_acc_cov[d] / tmp_k - tmp_mean * tmp_mean;
00168     }
00169   }
00170 
00171   return(acc_loglhood / N);
00172 
00173 }
00174 
00175 
00176 void MOG_diag_EM_sup::ml_iterate()
00177 {
00178   using std::cout;
00179   using std::endl;
00180   using std::setw;
00181   using std::showpos;
00182   using std::noshowpos;
00183   using std::scientific;
00184   using std::fixed;
00185   using std::flush;
00186   using std::setprecision;
00187 
00188   double avg_log_lhood_old = -1.0 * std::numeric_limits<double>::max();
00189 
00190   Real_Timer tt;
00191 
00192   if (verbose) {
00193     cout << "MOG_diag_EM_sup::ml_iterate()" << endl;
00194     cout << setw(14) << "iteration";
00195     cout << setw(14) << "avg_loglhood";
00196     cout << setw(14) << "delta";
00197     cout << setw(10) << "toc";
00198     cout << endl;
00199   }
00200 
00201   for (int i = 0; i < max_iter; i++) {
00202     sanitise_params();
00203     update_internals();
00204 
00205     if (verbose) tt.tic();
00206     double avg_log_lhood_new = ml_update_params();
00207 
00208     if (verbose) {
00209       double delta = avg_log_lhood_new - avg_log_lhood_old;
00210 
00211       cout << noshowpos << fixed;
00212       cout << setw(14) << i;
00213       cout << showpos << scientific << setprecision(3);
00214       cout << setw(14) << avg_log_lhood_new;
00215       cout << setw(14) << delta;
00216       cout << noshowpos << fixed;
00217       cout << setw(10) << tt.toc();
00218       cout << endl << flush;
00219     }
00220 
00221     if (avg_log_lhood_new <= avg_log_lhood_old)  break;
00222 
00223     avg_log_lhood_old = avg_log_lhood_new;
00224   }
00225 }
00226 
00227 
00228 void MOG_diag_EM_sup::ml(MOG_diag &model_in, Array<vec> &X_in, int max_iter_in, double var_floor_in, double weight_floor_in, bool verbose_in)
00229 {
00230 
00231   it_assert(model_in.is_valid(), "MOG_diag_EM_sup::ml(): initial model not valid");
00232   it_assert(check_array_uniformity(X_in), "MOG_diag_EM_sup::ml(): 'X' is empty or contains vectors of varying dimensionality");
00233   it_assert((max_iter_in > 0), "MOG_diag_EM_sup::ml(): 'max_iter' needs to be greater than zero");
00234 
00235   verbose = verbose_in;
00236 
00237   N = X_in.size();
00238 
00239   Array<vec> means_in = model_in.get_means();
00240   Array<vec> diag_covs_in = model_in.get_diag_covs();
00241   vec weights_in = model_in.get_weights();
00242 
00243   init(means_in, diag_covs_in, weights_in);
00244 
00245   means_in.set_size(0);
00246   diag_covs_in.set_size(0);
00247   weights_in.set_size(0);
00248 
00249   if (K > N) {
00250     it_warning("MOG_diag_EM_sup::ml(): WARNING: K > N");
00251   }
00252   else {
00253     if (K > N / 10) {
00254       it_warning("MOG_diag_EM_sup::ml(): WARNING: K > N/10");
00255     }
00256   }
00257 
00258   var_floor = var_floor_in;
00259   weight_floor = weight_floor_in;
00260 
00261   const double tiny = std::numeric_limits<double>::min();
00262   if (var_floor < tiny) var_floor = tiny;
00263   if (weight_floor < tiny) weight_floor = tiny;
00264   if (weight_floor > 1.0 / K) weight_floor = 1.0 / K;
00265 
00266   max_iter = max_iter_in;
00267 
00268   tmpvecK.set_size(K);
00269   tmpvecD.set_size(D);
00270   acc_loglhood_K.set_size(K);
00271 
00272   acc_means.set_size(K);
00273   for (int k = 0;k < K;k++) acc_means(k).set_size(D);
00274   acc_covs.set_size(K);
00275   for (int k = 0;k < K;k++) acc_covs(k).set_size(D);
00276 
00277   c_X = enable_c_access(X_in);
00278   c_tmpvecK = enable_c_access(tmpvecK);
00279   c_tmpvecD = enable_c_access(tmpvecD);
00280   c_acc_loglhood_K = enable_c_access(acc_loglhood_K);
00281   c_acc_means = enable_c_access(acc_means);
00282   c_acc_covs = enable_c_access(acc_covs);
00283 
00284   ml_iterate();
00285 
00286   model_in.init(means, diag_covs, weights);
00287 
00288   disable_c_access(c_X);
00289   disable_c_access(c_tmpvecK);
00290   disable_c_access(c_tmpvecD);
00291   disable_c_access(c_acc_loglhood_K);
00292   disable_c_access(c_acc_means);
00293   disable_c_access(c_acc_covs);
00294 
00295 
00296   tmpvecK.set_size(0);
00297   tmpvecD.set_size(0);
00298   acc_loglhood_K.set_size(0);
00299   acc_means.set_size(0);
00300   acc_covs.set_size(0);
00301 
00302   cleanup();
00303 
00304 }
00305 
00306 void MOG_diag_EM_sup::map(MOG_diag &, MOG_diag &, Array<vec> &, int, double,
00307                           double, double, bool)
00308 {
00309   it_error("MOG_diag_EM_sup::map(): not implemented yet");
00310 }
00311 
00312 
00313 //
00314 // convenience functions
00315 
00316 void MOG_diag_ML(MOG_diag &model_in, Array<vec> &X_in, int max_iter_in, double var_floor_in, double weight_floor_in, bool verbose_in)
00317 {
00318   MOG_diag_EM_sup EM;
00319   EM.ml(model_in, X_in, max_iter_in, var_floor_in, weight_floor_in, verbose_in);
00320 }
00321 
00322 void MOG_diag_MAP(MOG_diag &, MOG_diag &, Array<vec> &, int, double, double,
00323                   double, bool)
00324 {
00325   it_error("MOG_diag_MAP(): not implemented yet");
00326 }
00327 
00328 }
00329 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
SourceForge Logo

Generated on Sat Jul 9 2011 15:21:33 for IT++ by Doxygen 1.7.4