#include <itpp/srccode/gmm.h>
#include <itpp/srccode/vqtrain.h>
#include <itpp/base/math/elem_math.h>
#include <itpp/base/matfunc.h>
#include <itpp/base/specmat.h>
#include <itpp/base/random.h>
#include <itpp/base/timing.h>
#include <iostream>
#include <fstream>


namespace itpp
{

GMM::GMM()
{
  d = 0;
  M = 0;
}

GMM::GMM(std::string filename)
{
  load(filename);
}

GMM::GMM(int M_in, int d_in)
{
  M = M_in;
  d = d_in;
  m = zeros(M * d);
  sigma = zeros(M * d);
  w = 1.0 / M * ones(M);  // uniform mixture weights

  compute_internals();
}

void GMM::init_from_vq(const vec &codebook, int dim)
{
  mat C(dim, dim);
  int i;
  vec v;

  d = dim;
  M = codebook.length() / dim;

  m = codebook;
  w = ones(M) / double(M);

  // Use the average outer product of the codevectors as a common
  // diagonal covariance estimate for all mixtures.
  C.clear();
  for (i = 0; i < M; i++) {
    v = codebook.mid(i * d, d);
    C = C + outer_product(v, v);
  }
  C = 1.0 / M * C;
  sigma.set_length(M * d);
  for (i = 0; i < M; i++) {
    sigma.replace_mid(i * d, diag(C));
  }

  compute_internals();
}

void GMM::init(const vec &w_in, const mat &m_in, const mat &sigma_in)
{
  int i, j;
  d = m_in.rows();
  M = m_in.cols();

  m.set_length(M * d);
  sigma.set_length(M * d);
  for (i = 0; i < M; i++) {
    for (j = 0; j < d; j++) {
      m(i * d + j) = m_in(j, i);
      sigma(i * d + j) = sigma_in(j, i);
    }
  }
  w = w_in;

  compute_internals();
}

void GMM::set_mean(const mat &m_in)
{
  int i, j;

  d = m_in.rows();
  M = m_in.cols();

  m.set_length(M * d);
  for (i = 0; i < M; i++) {
    for (j = 0; j < d; j++) {
      m(i * d + j) = m_in(j, i);
    }
  }
  compute_internals();
}

void GMM::set_mean(int i, const vec &means, bool compflag)
{
  m.replace_mid(i * length(means), means);
  if (compflag) compute_internals();
}

void GMM::set_covariance(const mat &sigma_in)
{
  int i, j;

  d = sigma_in.rows();
  M = sigma_in.cols();

  sigma.set_length(M * d);
  for (i = 0; i < M; i++) {
    for (j = 0; j < d; j++) {
      sigma(i * d + j) = sigma_in(j, i);
    }
  }
  compute_internals();
}

void GMM::set_covariance(int i, const vec &covariances, bool compflag)
{
  sigma.replace_mid(i * length(covariances), covariances);
  if (compflag) compute_internals();
}

void GMM::marginalize(int d_new)
{
  it_error_if(d_new > d, "GMM.marginalize: cannot change to a larger dimension");

  vec mnew(d_new * M), sigmanew(d_new * M);
  int i, j;

  // Keep only the first d_new components of each mean and covariance vector.
  for (i = 0; i < M; i++) {
    for (j = 0; j < d_new; j++) {
      mnew(i * d_new + j) = m(i * d + j);
      sigmanew(i * d_new + j) = sigma(i * d + j);
    }
  }
  m = mnew;
  sigma = sigmanew;
  d = d_new;

  compute_internals();
}

void GMM::join(const GMM &newgmm)
{
  if (d == 0) {  // this GMM is empty; simply copy the other one
    w = newgmm.w;
    m = newgmm.m;
    sigma = newgmm.sigma;
    d = newgmm.d;
    M = newgmm.M;
  }
  else {
    it_error_if(d != newgmm.d, "GMM.join: cannot join GMMs of different dimension");

    // Weight the two mixtures in proportion to their number of components.
    w = concat(double(M) / (M + newgmm.M) * w,
               double(newgmm.M) / (M + newgmm.M) * newgmm.w);
    w = w / sum(w);
    m = concat(m, newgmm.m);
    sigma = concat(sigma, newgmm.sigma);

    M = M + newgmm.M;
  }
  compute_internals();
}

void GMM::clear()
{
  w.set_length(0);
  m.set_length(0);
  sigma.set_length(0);
  d = 0;
  M = 0;
}

void GMM::save(std::string filename)
{
  std::ofstream f(filename.c_str());
  int i, j;

  // File format: "M d", then the M weights, then the M mean vectors,
  // then the M diagonal covariance vectors (one vector per line).
  f << M << " " << d << std::endl;
  for (i = 0; i < w.length(); i++) {
    f << w(i) << std::endl;
  }
  for (i = 0; i < M; i++) {
    f << m(i * d);
    for (j = 1; j < d; j++) {
      f << " " << m(i * d + j);
    }
    f << std::endl;
  }
  for (i = 0; i < M; i++) {
    f << sigma(i * d);
    for (j = 1; j < d; j++) {
      f << " " << sigma(i * d + j);
    }
    f << std::endl;
  }
}

void GMM::load(std::string filename)
{
  std::ifstream GMMFile(filename.c_str());
  int i, j;

  it_error_if(!GMMFile, std::string("GMM::load : cannot open file ") + filename);

  GMMFile >> M >> d;

  w.set_length(M);
  for (i = 0; i < M; i++) {
    GMMFile >> w(i);
  }
  m.set_length(M * d);
  for (i = 0; i < M; i++) {
    for (j = 0; j < d; j++) {
      GMMFile >> m(i * d + j);
    }
  }
  sigma.set_length(M * d);
  for (i = 0; i < M; i++) {
    for (j = 0; j < d; j++) {
      GMMFile >> sigma(i * d + j);
    }
  }
  compute_internals();
  std::cout << " mixtures:" << M << " dim:" << d << std::endl;
}

double GMM::likelihood(const vec &x)
{
  // f(x) = sum_i w(i) * N(x; m_i, diag(sigma_i))
  double fx = 0;
  int i;

  for (i = 0; i < M; i++) {
    fx += w(i) * likelihood_aposteriori(x, i);
  }
  return fx;
}

vec GMM::likelihood_aposteriori(const vec &x)
{
  vec v(M);
  int i;

  for (i = 0; i < M; i++) {
    v(i) = w(i) * likelihood_aposteriori(x, i);
  }
  return v;
}

double GMM::likelihood_aposteriori(const vec &x, int mixture)
{
  int j;
  double s;

  it_error_if(d != x.length(), "GMM::likelihood_aposteriori : dimensions do not match");
  s = 0;
  for (j = 0; j < d; j++) {
    s += normexp(mixture * d + j) * sqr(x(j) - m(mixture * d + j));
  }
  return normweight(mixture) * std::exp(s);
}

void GMM::compute_internals()
{
  // Precompute the normalization constants and exponent factors of the
  // diagonal-covariance Gaussian densities.
  int i, j;
  double s;
  double constant = 1.0 / std::pow(2 * pi, d / 2.0);

  normweight.set_length(M);
  normexp.set_length(M * d);

  for (i = 0; i < M; i++) {
    s = 1;
    for (j = 0; j < d; j++) {
      normexp(i * d + j) = -0.5 / sigma(i * d + j);
      s *= sigma(i * d + j);
    }
    normweight(i) = constant / std::sqrt(s);
  }
}

vec GMM::draw_sample()
{
  static bool first = true;
  static vec cumweight;
  double u = randu();
  int k;

  if (first) {
    first = false;
    cumweight = cumsum(w);
    it_error_if(std::abs(cumweight(length(cumweight) - 1) - 1) > 1e-6, "weights do not sum to 1");
    cumweight(length(cumweight) - 1) = 1;
  }
  // Select a mixture component according to the weights, then draw from
  // the corresponding diagonal-covariance Gaussian.
  k = 0;
  while (u > cumweight(k)) k++;

  return elem_mult(sqrt(sigma.mid(k * d, d)), randn(d)) + m.mid(k * d, d);
}

GMM gmmtrain(Array<vec> &TrainingData, int M, int NOITER, bool VERBOSE)
{
  mat mean;
  int i, j, d = TrainingData(0).length();
  vec sig;
  GMM gmm(M, d);
  vec m(d * M);
  vec sigma(d * M);
  vec w(M);
  vec normweight(M);
  vec normexp(d * M);
  double LL = 0, LLold, fx;
  double constant = 1.0 / std::pow(2 * pi, d / 2.0);
  int T = TrainingData.length();
  vec x1;
  int t, n;
  vec msum(d * M);
  vec sigmasum(d * M);
  vec wsum(M);
  vec p_aposteriori(M);
  vec x2;
  double s;
  vec temp1, temp2;
  //double MINIMUM_VARIANCE=0.03;

  //----------- initialization -----------------------------------

  // Initialize the means with a VQ codebook and use half the average
  // signal energy as the initial (common) diagonal covariance.
  mean = vqtrain(TrainingData, M, 200000, 0.5, VERBOSE);
  for (i = 0; i < M; i++) gmm.set_mean(i, mean.get_col(i), false);
  // for (i=0;i<M;i++) gmm.set_mean(i,TrainingData(randi(0,TrainingData.length()-1)),false);
  sig = zeros(d);
  for (i = 0; i < TrainingData.length(); i++) sig += sqr(TrainingData(i));
  sig /= TrainingData.length();
  for (i = 0; i < M; i++) gmm.set_covariance(i, 0.5 * sig, false);

  gmm.set_weight(1.0 / M * ones(M));

  //----------- optimization (EM iterations) ----------------------

  tic();
  for (i = 0; i < M; i++) {
    temp1 = gmm.get_mean(i);
    temp2 = gmm.get_covariance(i);
    for (j = 0; j < d; j++) {
      m(i * d + j) = temp1(j);
      sigma(i * d + j) = temp2(j);
    }
    w(i) = gmm.get_weight(i);
  }
  for (n = 0; n < NOITER; n++) {
    // Precompute the per-mixture normalization terms for this iteration.
    for (i = 0; i < M; i++) {
      s = 1;
      for (j = 0; j < d; j++) {
        normexp(i * d + j) = -0.5 / sigma(i * d + j);
        s *= sigma(i * d + j);
      }
      normweight(i) = constant * w(i) / std::sqrt(s);
    }
    LLold = LL;
    wsum.clear();
    msum.clear();
    sigmasum.clear();
    LL = 0;
    // E-step: accumulate posterior-weighted sufficient statistics.
    for (t = 0; t < T; t++) {
      x1 = TrainingData(t);
      x2 = sqr(x1);
      fx = 0;
      for (i = 0; i < M; i++) {
        s = 0;
        for (j = 0; j < d; j++) {
          s += normexp(i * d + j) * sqr(x1(j) - m(i * d + j));
        }
        p_aposteriori(i) = normweight(i) * std::exp(s);
        fx += p_aposteriori(i);
      }
      p_aposteriori /= fx;
      LL = LL + std::log(fx);

      for (i = 0; i < M; i++) {
        wsum(i) += p_aposteriori(i);
        for (j = 0; j < d; j++) {
          msum(i * d + j) += p_aposteriori(i) * x1(j);
          sigmasum(i * d + j) += p_aposteriori(i) * x2(j);
        }
      }
    }
    // M-step: re-estimate means, diagonal covariances and weights.
    for (i = 0; i < M; i++) {
      for (j = 0; j < d; j++) {
        m(i * d + j) = msum(i * d + j) / wsum(i);
        sigma(i * d + j) = sigmasum(i * d + j) / wsum(i) - sqr(m(i * d + j));
      }
      w(i) = wsum(i) / T;
    }
    LL = LL / T;

    if (std::abs((LL - LLold) / LL) < 1e-6) break;
    if (VERBOSE) {
      std::cout << n << ": " << LL << " " << std::abs((LL - LLold) / LL) << " " << toc() << std::endl;
      std::cout << "---------------------------------------" << std::endl;
      tic();
    }
    else {
      std::cout << n << ": LL = " << LL << " " << std::abs((LL - LLold) / LL) << "\r";
      std::cout.flush();
    }
  }
  for (i = 0; i < M; i++) {
    gmm.set_mean(i, m.mid(i * d, d), false);
    gmm.set_covariance(i, sigma.mid(i * d, d), false);
  }
  gmm.set_weight(w);
  return gmm;
}

} // namespace itpp
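The sketch below is not part of the library source; it is a minimal usage example based only on the interface visible in the listing above (gmmtrain, likelihood, draw_sample, save). The synthetic training set, its size, and the file name "example.gmm" are illustrative assumptions.

// Minimal usage sketch for the GMM class and gmmtrain() shown above.
// Assumptions: IT++ is installed and the training data below is synthetic.
#include <itpp/srccode/gmm.h>
#include <itpp/base/random.h>
#include <iostream>

using namespace itpp;

int main()
{
  // Build a toy training set: 1000 two-dimensional Gaussian vectors.
  Array<vec> training(1000);
  for (int t = 0; t < training.length(); t++)
    training(t) = randn(2);

  // Fit a 4-component GMM with at most 100 EM iterations (non-verbose).
  GMM gmm = gmmtrain(training, 4, 100, false);

  // Evaluate the mixture density at a point, draw a sample, store the model.
  vec x = randn(2);
  std::cout << "f(x) = " << gmm.likelihood(x) << std::endl;
  vec y = gmm.draw_sample();
  gmm.save("example.gmm");

  return 0;
}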