00001 00030 #include <itpp/stat/mog_diag_em.h> 00031 #include <itpp/base/math/log_exp.h> 00032 #include <itpp/base/timing.h> 00033 00034 #include <iostream> 00035 #include <iomanip> 00036 00037 namespace itpp 00038 { 00039 00041 void inline MOG_diag_EM_sup::update_internals() 00042 { 00043 00044 double Ddiv2_log_2pi = D / 2.0 * std::log(m_2pi); 00045 00046 for (int k = 0;k < K;k++) c_log_weights[k] = std::log(c_weights[k]); 00047 00048 for (int k = 0;k < K;k++) { 00049 double acc = 0.0; 00050 double * c_diag_cov = c_diag_covs[k]; 00051 double * c_diag_cov_inv_etc = c_diag_covs_inv_etc[k]; 00052 00053 for (int d = 0;d < D;d++) { 00054 double tmp = c_diag_cov[d]; 00055 c_diag_cov_inv_etc[d] = 1.0 / (2.0 * tmp); 00056 acc += std::log(tmp); 00057 } 00058 00059 c_log_det_etc[k] = -Ddiv2_log_2pi - 0.5 * acc; 00060 } 00061 00062 } 00063 00064 00066 void inline MOG_diag_EM_sup::sanitise_params() 00067 { 00068 00069 double acc = 0.0; 00070 for (int k = 0;k < K;k++) { 00071 if (c_weights[k] < weight_floor) c_weights[k] = weight_floor; 00072 if (c_weights[k] > 1.0) c_weights[k] = 1.0; 00073 acc += c_weights[k]; 00074 } 00075 for (int k = 0;k < K;k++) c_weights[k] /= acc; 00076 00077 for (int k = 0;k < K;k++) 00078 for (int d = 0;d < D;d++) 00079 if (c_diag_covs[k][d] < var_floor) c_diag_covs[k][d] = var_floor; 00080 00081 } 00082 00084 double MOG_diag_EM_sup::ml_update_params() 00085 { 00086 00087 double acc_loglhood = 0.0; 00088 00089 for (int k = 0;k < K;k++) { 00090 c_acc_loglhood_K[k] = 0.0; 00091 00092 double * c_acc_mean = c_acc_means[k]; 00093 double * c_acc_cov = c_acc_covs[k]; 00094 00095 for (int d = 0;d < D;d++) { c_acc_mean[d] = 0.0; c_acc_cov[d] = 0.0; } 00096 } 00097 00098 for (int n = 0;n < N;n++) { 00099 double * c_x = c_X[n]; 00100 00101 bool danger = paranoid; 00102 for (int k = 0;k < K;k++) { 00103 double tmp = c_log_weights[k] + MOG_diag::log_lhood_single_gaus_internal(c_x, k); 00104 c_tmpvecK[k] = tmp; 00105 if (tmp >= log_max_K) danger = true; 00106 } 00107 00108 if (danger) { 00109 00110 double log_sum = c_tmpvecK[0]; 00111 for (int k = 1;k < K;k++) log_sum = log_add(log_sum, c_tmpvecK[k]); 00112 acc_loglhood += log_sum; 00113 00114 for (int k = 0;k < K;k++) { 00115 00116 double * c_acc_mean = c_acc_means[k]; 00117 double * c_acc_cov = c_acc_covs[k]; 00118 00119 double tmp_k = trunc_exp(c_tmpvecK[k] - log_sum); 00120 acc_loglhood_K[k] += tmp_k; 00121 00122 for (int d = 0;d < D;d++) { 00123 double tmp_x = c_x[d]; 00124 c_acc_mean[d] += tmp_k * tmp_x; 00125 c_acc_cov[d] += tmp_k * tmp_x * tmp_x; 00126 } 00127 } 00128 } 00129 else { 00130 00131 double sum = 0.0; 00132 for (int k = 0;k < K;k++) { double tmp = std::exp(c_tmpvecK[k]); c_tmpvecK[k] = tmp; sum += tmp; } 00133 acc_loglhood += std::log(sum); 00134 00135 for (int k = 0;k < K;k++) { 00136 00137 double * c_acc_mean = c_acc_means[k]; 00138 double * c_acc_cov = c_acc_covs[k]; 00139 00140 double tmp_k = c_tmpvecK[k] / sum; 00141 c_acc_loglhood_K[k] += tmp_k; 00142 00143 for (int d = 0;d < D;d++) { 00144 double tmp_x = c_x[d]; 00145 c_acc_mean[d] += tmp_k * tmp_x; 00146 c_acc_cov[d] += tmp_k * tmp_x * tmp_x; 00147 } 00148 } 00149 } 00150 } 00151 00152 for (int k = 0;k < K;k++) { 00153 00154 double * c_mean = c_means[k]; 00155 double * c_diag_cov = c_diag_covs[k]; 00156 00157 double * c_acc_mean = c_acc_means[k]; 00158 double * c_acc_cov = c_acc_covs[k]; 00159 00160 double tmp_k = c_acc_loglhood_K[k]; 00161 00162 c_weights[k] = tmp_k / N; 00163 00164 for (int d = 0;d < D;d++) { 00165 double tmp_mean = c_acc_mean[d] / tmp_k; 00166 c_mean[d] = tmp_mean; 00167 c_diag_cov[d] = c_acc_cov[d] / tmp_k - tmp_mean * tmp_mean; 00168 } 00169 } 00170 00171 return(acc_loglhood / N); 00172 00173 } 00174 00175 00176 void MOG_diag_EM_sup::ml_iterate() 00177 { 00178 using std::cout; 00179 using std::endl; 00180 using std::setw; 00181 using std::showpos; 00182 using std::noshowpos; 00183 using std::scientific; 00184 using std::fixed; 00185 using std::flush; 00186 using std::setprecision; 00187 00188 double avg_log_lhood_old = -1.0 * std::numeric_limits<double>::max(); 00189 00190 Real_Timer tt; 00191 00192 if (verbose) { 00193 cout << "MOG_diag_EM_sup::ml_iterate()" << endl; 00194 cout << setw(14) << "iteration"; 00195 cout << setw(14) << "avg_loglhood"; 00196 cout << setw(14) << "delta"; 00197 cout << setw(10) << "toc"; 00198 cout << endl; 00199 } 00200 00201 for (int i = 0; i < max_iter; i++) { 00202 sanitise_params(); 00203 update_internals(); 00204 00205 if (verbose) tt.tic(); 00206 double avg_log_lhood_new = ml_update_params(); 00207 00208 if (verbose) { 00209 double delta = avg_log_lhood_new - avg_log_lhood_old; 00210 00211 cout << noshowpos << fixed; 00212 cout << setw(14) << i; 00213 cout << showpos << scientific << setprecision(3); 00214 cout << setw(14) << avg_log_lhood_new; 00215 cout << setw(14) << delta; 00216 cout << noshowpos << fixed; 00217 cout << setw(10) << tt.toc(); 00218 cout << endl << flush; 00219 } 00220 00221 if (avg_log_lhood_new <= avg_log_lhood_old) break; 00222 00223 avg_log_lhood_old = avg_log_lhood_new; 00224 } 00225 } 00226 00227 00228 void MOG_diag_EM_sup::ml(MOG_diag &model_in, Array<vec> &X_in, int max_iter_in, double var_floor_in, double weight_floor_in, bool verbose_in) 00229 { 00230 00231 it_assert(model_in.is_valid(), "MOG_diag_EM_sup::ml(): initial model not valid"); 00232 it_assert(check_array_uniformity(X_in), "MOG_diag_EM_sup::ml(): 'X' is empty or contains vectors of varying dimensionality"); 00233 it_assert((max_iter_in > 0), "MOG_diag_EM_sup::ml(): 'max_iter' needs to be greater than zero"); 00234 00235 verbose = verbose_in; 00236 00237 N = X_in.size(); 00238 00239 Array<vec> means_in = model_in.get_means(); 00240 Array<vec> diag_covs_in = model_in.get_diag_covs(); 00241 vec weights_in = model_in.get_weights(); 00242 00243 init(means_in, diag_covs_in, weights_in); 00244 00245 means_in.set_size(0); 00246 diag_covs_in.set_size(0); 00247 weights_in.set_size(0); 00248 00249 if (K > N) { 00250 it_warning("MOG_diag_EM_sup::ml(): WARNING: K > N"); 00251 } 00252 else { 00253 if (K > N / 10) { 00254 it_warning("MOG_diag_EM_sup::ml(): WARNING: K > N/10"); 00255 } 00256 } 00257 00258 var_floor = var_floor_in; 00259 weight_floor = weight_floor_in; 00260 00261 const double tiny = std::numeric_limits<double>::min(); 00262 if (var_floor < tiny) var_floor = tiny; 00263 if (weight_floor < tiny) weight_floor = tiny; 00264 if (weight_floor > 1.0 / K) weight_floor = 1.0 / K; 00265 00266 max_iter = max_iter_in; 00267 00268 tmpvecK.set_size(K); 00269 tmpvecD.set_size(D); 00270 acc_loglhood_K.set_size(K); 00271 00272 acc_means.set_size(K); 00273 for (int k = 0;k < K;k++) acc_means(k).set_size(D); 00274 acc_covs.set_size(K); 00275 for (int k = 0;k < K;k++) acc_covs(k).set_size(D); 00276 00277 c_X = enable_c_access(X_in); 00278 c_tmpvecK = enable_c_access(tmpvecK); 00279 c_tmpvecD = enable_c_access(tmpvecD); 00280 c_acc_loglhood_K = enable_c_access(acc_loglhood_K); 00281 c_acc_means = enable_c_access(acc_means); 00282 c_acc_covs = enable_c_access(acc_covs); 00283 00284 ml_iterate(); 00285 00286 model_in.init(means, diag_covs, weights); 00287 00288 disable_c_access(c_X); 00289 disable_c_access(c_tmpvecK); 00290 disable_c_access(c_tmpvecD); 00291 disable_c_access(c_acc_loglhood_K); 00292 disable_c_access(c_acc_means); 00293 disable_c_access(c_acc_covs); 00294 00295 00296 tmpvecK.set_size(0); 00297 tmpvecD.set_size(0); 00298 acc_loglhood_K.set_size(0); 00299 acc_means.set_size(0); 00300 acc_covs.set_size(0); 00301 00302 cleanup(); 00303 00304 } 00305 00306 void MOG_diag_EM_sup::map(MOG_diag &, MOG_diag &, Array<vec> &, int, double, 00307 double, double, bool) 00308 { 00309 it_error("MOG_diag_EM_sup::map(): not implemented yet"); 00310 } 00311 00312 00313 // 00314 // convenience functions 00315 00316 void MOG_diag_ML(MOG_diag &model_in, Array<vec> &X_in, int max_iter_in, double var_floor_in, double weight_floor_in, bool verbose_in) 00317 { 00318 MOG_diag_EM_sup EM; 00319 EM.ml(model_in, X_in, max_iter_in, var_floor_in, weight_floor_in, verbose_in); 00320 } 00321 00322 void MOG_diag_MAP(MOG_diag &, MOG_diag &, Array<vec> &, int, double, double, 00323 double, bool) 00324 { 00325 it_error("MOG_diag_MAP(): not implemented yet"); 00326 } 00327 00328 } 00329
Generated on Sat Jul 9 2011 15:21:33 for IT++ by Doxygen 1.7.4