00001 00029 #include <itpp/base/math/log_exp.h> 00030 #include <itpp/stat/mog_diag.h> 00031 #include <cstdlib> 00032 00033 00034 namespace itpp 00035 { 00036 00037 double MOG_diag::log_lhood_single_gaus_internal(const double * c_x_in, const int k) const 00038 { 00039 00040 const double * c_mean = c_means[k]; 00041 const double * c_diag_cov_inv_etc = c_diag_covs_inv_etc[k]; 00042 00043 double acc = 0.0; 00044 00045 for (int d = 0; d < D; d++) { 00046 double tmp_val = c_x_in[d] - c_mean[d]; 00047 acc += (tmp_val * tmp_val) * c_diag_cov_inv_etc[d]; 00048 } 00049 return(c_log_det_etc[k] - acc); 00050 } 00051 00052 00053 double MOG_diag::log_lhood_single_gaus_internal(const vec &x_in, const int k) const 00054 { 00055 return log_lhood_single_gaus_internal(x_in._data(), k); 00056 } 00057 00058 00059 double MOG_diag::log_lhood_single_gaus(const double * c_x_in, const int k) const 00060 { 00061 if (do_checks) { 00062 it_assert(valid, "MOG_diag::log_lhood_single_gaus(): model not valid"); 00063 it_assert(((k >= 0) && (k < K)), "MOG::log_lhood_single_gaus(): k specifies a non-existant Gaussian"); 00064 } 00065 return log_lhood_single_gaus_internal(c_x_in, k); 00066 } 00067 00068 00069 double MOG_diag::log_lhood_single_gaus(const vec &x_in, const int k) const 00070 { 00071 if (do_checks) { 00072 it_assert(valid, "MOG_diag::log_lhood_single_gaus(): model not valid"); 00073 it_assert(check_size(x_in), "MOG_diag::log_lhood_single_gaus(): x has wrong dimensionality"); 00074 it_assert(((k >= 0) && (k < K)), "MOG::log_lhood_single_gaus(): k specifies a non-existant Gaussian"); 00075 } 00076 return log_lhood_single_gaus_internal(x_in._data(), k); 00077 } 00078 00079 00080 double MOG_diag::log_lhood_internal(const double * c_x_in) 00081 { 00082 00083 bool danger = paranoid; 00084 00085 for (int k = 0;k < K;k++) { 00086 double tmp = c_log_weights[k] + log_lhood_single_gaus_internal(c_x_in, k); 00087 c_tmpvecK[k] = tmp; 00088 00089 if (tmp >= log_max_K) danger = true; 00090 } 00091 
00092 00093 if (danger) { 00094 double log_sum = c_tmpvecK[0]; 00095 for (int k = 1; k < K; k++) log_sum = log_add(log_sum, c_tmpvecK[k]); 00096 return(log_sum); 00097 } 00098 else { 00099 double sum = 0.0; 00100 for (int k = 0;k < K;k++) sum += std::exp(c_tmpvecK[k]); 00101 return(std::log(sum)); 00102 } 00103 } 00104 00105 00106 double MOG_diag::log_lhood_internal(const vec &x_in) 00107 { 00108 return log_lhood_internal(x_in._data()); 00109 } 00110 00111 00112 double MOG_diag::log_lhood(const vec &x_in) 00113 { 00114 if (do_checks) { 00115 it_assert(valid, "MOG_diag::log_lhood(): model not valid"); 00116 it_assert(check_size(x_in), "MOG_diag::log_lhood(): x has wrong dimensionality"); 00117 } 00118 return log_lhood_internal(x_in._data()); 00119 } 00120 00121 00122 double MOG_diag::log_lhood(const double * c_x_in) 00123 { 00124 if (do_checks) { 00125 it_assert(valid, "MOG_diag::log_lhood(): model not valid"); 00126 it_assert((c_x_in != 0), "MOG_diag::log_lhood(): c_x_in is a null pointer"); 00127 } 00128 00129 return log_lhood_internal(c_x_in); 00130 } 00131 00132 00133 double MOG_diag::lhood_internal(const double * c_x_in) 00134 { 00135 00136 bool danger = paranoid; 00137 00138 for (int k = 0;k < K;k++) { 00139 double tmp = c_log_weights[k] + log_lhood_single_gaus_internal(c_x_in, k); 00140 c_tmpvecK[k] = tmp; 00141 00142 if (tmp >= log_max_K) danger = true; 00143 } 00144 00145 00146 if (danger) { 00147 double log_sum = c_tmpvecK[0]; 00148 for (int k = 1; k < K; k++) log_sum = log_add(log_sum, c_tmpvecK[k]); 00149 return(trunc_exp(log_sum)); 00150 } 00151 else { 00152 double sum = 0.0; 00153 for (int k = 0;k < K;k++) sum += std::exp(c_tmpvecK[k]); 00154 return(sum); 00155 } 00156 } 00157 00158 double MOG_diag::lhood_internal(const vec &x_in) { return lhood_internal(x_in._data()); } 00159 00160 double MOG_diag::lhood(const vec &x_in) 00161 { 00162 if (do_checks) { 00163 it_assert(valid, "MOG_diag::lhood(): model not valid"); 00164 it_assert(check_size(x_in), 
"MOG_diag::lhood(): x has wrong dimensionality"); 00165 } 00166 return lhood_internal(x_in._data()); 00167 } 00168 00169 00170 double MOG_diag::lhood(const double * c_x_in) 00171 { 00172 if (do_checks) { 00173 it_assert(valid, "MOG_diag::lhood(): model not valid"); 00174 it_assert((c_x_in != 0), "MOG_diag::lhood(): c_x_in is a null pointer"); 00175 } 00176 00177 return lhood_internal(c_x_in); 00178 } 00179 00180 00181 double MOG_diag::avg_log_lhood(const double ** c_x_in, const int N) 00182 { 00183 if (do_checks) { 00184 it_assert(valid, "MOG_diag::avg_log_lhood(): model not valid"); 00185 it_assert((c_x_in != 0), "MOG_diag::avg_log_lhood(): c_x_in is a null pointer"); 00186 it_assert((N >= 0), "MOG_diag::avg_log_lhood(): N is zero or negative"); 00187 } 00188 00189 double acc = 0.0; 00190 for (int n = 0;n < N;n++) acc += log_lhood_internal(c_x_in[n]); 00191 return(acc / N); 00192 } 00193 00194 00195 double MOG_diag::avg_log_lhood(const Array<vec> &X_in) 00196 { 00197 if (do_checks) { 00198 it_assert(valid, "MOG_diag::avg_log_lhood(): model not valid"); 00199 it_assert(check_size(X_in), "MOG_diag::avg_log_lhood(): X is empty or at least one vector has the wrong dimensionality"); 00200 } 00201 const int N = X_in.size(); 00202 double acc = 0.0; 00203 for (int n = 0;n < N;n++) acc += log_lhood_internal(X_in(n)._data()); 00204 return(acc / N); 00205 } 00206 00207 void MOG_diag::zero_all_ptrs() 00208 { 00209 c_means = 0; 00210 c_diag_covs = 0; 00211 c_diag_covs_inv_etc = 0; 00212 c_weights = 0; 00213 c_log_weights = 0; 00214 c_log_det_etc = 0; 00215 c_tmpvecK = 0; 00216 } 00217 00218 00219 void MOG_diag::free_all_ptrs() 00220 { 00221 c_means = disable_c_access(c_means); 00222 c_diag_covs = disable_c_access(c_diag_covs); 00223 c_diag_covs_inv_etc = disable_c_access(c_diag_covs_inv_etc); 00224 c_weights = disable_c_access(c_weights); 00225 c_log_weights = disable_c_access(c_log_weights); 00226 c_log_det_etc = disable_c_access(c_log_det_etc); 00227 c_tmpvecK = 
disable_c_access(c_tmpvecK); 00228 } 00229 00230 00231 void MOG_diag::setup_means() 00232 { 00233 MOG_generic::setup_means(); 00234 disable_c_access(c_means); 00235 c_means = enable_c_access(means); 00236 } 00237 00238 00239 void MOG_diag::setup_covs() 00240 { 00241 MOG_generic::setup_covs(); 00242 if (full) return; 00243 00244 disable_c_access(c_diag_covs); 00245 disable_c_access(c_diag_covs_inv_etc); 00246 disable_c_access(c_log_det_etc); 00247 00248 c_diag_covs = enable_c_access(diag_covs); 00249 c_diag_covs_inv_etc = enable_c_access(diag_covs_inv_etc); 00250 c_log_det_etc = enable_c_access(log_det_etc); 00251 } 00252 00253 00254 void MOG_diag::setup_weights() 00255 { 00256 MOG_generic::setup_weights(); 00257 00258 disable_c_access(c_weights); 00259 disable_c_access(c_log_weights); 00260 00261 c_weights = enable_c_access(weights); 00262 c_log_weights = enable_c_access(log_weights); 00263 } 00264 00265 00266 void MOG_diag::setup_misc() 00267 { 00268 disable_c_access(c_tmpvecK); 00269 tmpvecK.set_size(K); 00270 c_tmpvecK = enable_c_access(tmpvecK); 00271 00272 MOG_generic::setup_misc(); 00273 if (full) convert_to_diag_internal(); 00274 } 00275 00276 00277 void MOG_diag::load(const std::string &name_in) 00278 { 00279 MOG_generic::load(name_in); 00280 if (full) convert_to_diag(); 00281 } 00282 00283 00284 double ** MOG_diag::enable_c_access(Array<vec> & A_in) 00285 { 00286 int rows = A_in.size(); 00287 double ** A = (double **)std::malloc(rows * sizeof(double *)); 00288 if (A) for (int row = 0;row < rows;row++) A[row] = A_in(row)._data(); 00289 return(A); 00290 } 00291 00292 int ** MOG_diag::enable_c_access(Array<ivec> & A_in) 00293 { 00294 int rows = A_in.size(); 00295 int ** A = (int **)std::malloc(rows * sizeof(int *)); 00296 if (A) for (int row = 0;row < rows;row++) A[row] = A_in(row)._data(); 00297 return(A); 00298 } 00299 00300 double ** MOG_diag::disable_c_access(double ** A_in) { if (A_in) std::free(A_in); return(0); } 00301 int ** 
MOG_diag::disable_c_access(int ** A_in) { if (A_in) std::free(A_in); return(0); } 00302 00303 double * MOG_diag::enable_c_access(vec & v_in) { return v_in._data(); } 00304 int * MOG_diag::enable_c_access(ivec & v_in) { return v_in._data(); } 00305 00306 double * MOG_diag::disable_c_access(double *) { return(0); } 00307 int * MOG_diag::disable_c_access(int *) { return(0); } 00308 00309 }
// Generated on Sat Jul 9 2011 15:21:33 for IT++ by Doxygen 1.7.4