IT++ Logo
ls_solve.cpp
Go to the documentation of this file.
00001 
00029 #ifndef _MSC_VER
00030 #  include <itpp/config.h>
00031 #else
00032 #  include <itpp/config_msvc.h>
00033 #endif
00034 
00035 #if defined(HAVE_LAPACK)
00036 #  include <itpp/base/algebra/lapack.h>
00037 #endif
00038 
00039 #include <itpp/base/algebra/ls_solve.h>
00040 
00041 
00042 namespace itpp
00043 {
00044 
00045 // ----------- ls_solve_chol -----------------------------------------------------------
00046 
00047 #if defined(HAVE_LAPACK)
00048 
00049 bool ls_solve_chol(const mat &A, const vec &b, vec &x)
00050 {
00051   int n, lda, ldb, nrhs, info;
00052   n = lda = ldb = A.rows();
00053   nrhs = 1;
00054   char uplo = 'U';
00055 
00056   it_assert_debug(A.cols() == n, "ls_solve_chol: System-matrix is not square");
00057   it_assert_debug(n == b.size(), "The number of rows in A must equal the length of b!");
00058 
00059   ivec ipiv(n);
00060   x = b;
00061   mat Chol = A;
00062 
00063   dposv_(&uplo, &n, &nrhs, Chol._data(), &lda, x._data(), &ldb, &info);
00064 
00065   return (info == 0);
00066 }
00067 
00068 
00069 bool ls_solve_chol(const mat &A, const mat &B, mat &X)
00070 {
00071   int n, lda, ldb, nrhs, info;
00072   n = lda = ldb = A.rows();
00073   nrhs = B.cols();
00074   char uplo = 'U';
00075 
00076   it_assert_debug(A.cols() == n, "ls_solve_chol: System-matrix is not square");
00077   it_assert_debug(n == B.rows(), "The number of rows in A must equal the length of B!");
00078 
00079   ivec ipiv(n);
00080   X = B;
00081   mat Chol = A;
00082 
00083   dposv_(&uplo, &n, &nrhs, Chol._data(), &lda, X._data(), &ldb, &info);
00084 
00085   return (info == 0);
00086 }
00087 
00088 bool ls_solve_chol(const cmat &A, const cvec &b, cvec &x)
00089 {
00090   int n, lda, ldb, nrhs, info;
00091   n = lda = ldb = A.rows();
00092   nrhs = 1;
00093   char uplo = 'U';
00094 
00095   it_assert_debug(A.cols() == n, "ls_solve_chol: System-matrix is not square");
00096   it_assert_debug(n == b.size(), "The number of rows in A must equal the length of b!");
00097 
00098   ivec ipiv(n);
00099   x = b;
00100   cmat Chol = A;
00101 
00102   zposv_(&uplo, &n, &nrhs, Chol._data(), &lda, x._data(), &ldb, &info);
00103 
00104   return (info == 0);
00105 }
00106 
00107 bool ls_solve_chol(const cmat &A, const cmat &B, cmat &X)
00108 {
00109   int n, lda, ldb, nrhs, info;
00110   n = lda = ldb = A.rows();
00111   nrhs = B.cols();
00112   char uplo = 'U';
00113 
00114   it_assert_debug(A.cols() == n, "ls_solve_chol: System-matrix is not square");
00115   it_assert_debug(n == B.rows(), "The number of rows in A must equal the length of B!");
00116 
00117   ivec ipiv(n);
00118   X = B;
00119   cmat Chol = A;
00120 
00121   zposv_(&uplo, &n, &nrhs, Chol._data(), &lda, X._data(), &ldb, &info);
00122 
00123   return (info == 0);
00124 }
00125 
00126 #else
00127 
00128 bool ls_solve_chol(const mat &A, const vec &b, vec &x)
00129 {
00130   it_error("LAPACK library is needed to use ls_solve_chol() function");
00131   return false;
00132 }
00133 
00134 bool ls_solve_chol(const mat &A, const mat &B, mat &X)
00135 {
00136   it_error("LAPACK library is needed to use ls_solve_chol() function");
00137   return false;
00138 }
00139 
00140 bool ls_solve_chol(const cmat &A, const cvec &b, cvec &x)
00141 {
00142   it_error("LAPACK library is needed to use ls_solve_chol() function");
00143   return false;
00144 }
00145 
00146 bool ls_solve_chol(const cmat &A, const cmat &B, cmat &X)
00147 {
00148   it_error("LAPACK library is needed to use ls_solve_chol() function");
00149   return false;
00150 }
00151 
00152 #endif // HAVE_LAPACK
00153 
00154 vec ls_solve_chol(const mat &A, const vec &b)
00155 {
00156   vec x;
00157   bool info;
00158   info = ls_solve_chol(A, b, x);
00159   it_assert_debug(info, "ls_solve_chol: Failed solving the system");
00160   return x;
00161 }
00162 
00163 mat ls_solve_chol(const mat &A, const mat &B)
00164 {
00165   mat X;
00166   bool info;
00167   info = ls_solve_chol(A, B, X);
00168   it_assert_debug(info, "ls_solve_chol: Failed solving the system");
00169   return X;
00170 }
00171 
00172 cvec ls_solve_chol(const cmat &A, const cvec &b)
00173 {
00174   cvec x;
00175   bool info;
00176   info = ls_solve_chol(A, b, x);
00177   it_assert_debug(info, "ls_solve_chol: Failed solving the system");
00178   return x;
00179 }
00180 
00181 cmat ls_solve_chol(const cmat &A, const cmat &B)
00182 {
00183   cmat X;
00184   bool info;
00185   info = ls_solve_chol(A, B, X);
00186   it_assert_debug(info, "ls_solve_chol: Failed solving the system");
00187   return X;
00188 }
00189 
00190 
00191 // --------- ls_solve ---------------------------------------------------------------
00192 #if defined(HAVE_LAPACK)
00193 
00194 bool ls_solve(const mat &A, const vec &b, vec &x)
00195 {
00196   int n, lda, ldb, nrhs, info;
00197   n = lda = ldb = A.rows();
00198   nrhs = 1;
00199 
00200   it_assert_debug(A.cols() == n, "ls_solve: System-matrix is not square");
00201   it_assert_debug(n == b.size(), "The number of rows in A must equal the length of b!");
00202 
00203   ivec ipiv(n);
00204   x = b;
00205   mat LU = A;
00206 
00207   dgesv_(&n, &nrhs, LU._data(), &lda, ipiv._data(), x._data(), &ldb, &info);
00208 
00209   return (info == 0);
00210 }
00211 
00212 bool ls_solve(const mat &A, const mat &B, mat &X)
00213 {
00214   int n, lda, ldb, nrhs, info;
00215   n = lda = ldb = A.rows();
00216   nrhs = B.cols();
00217 
00218   it_assert_debug(A.cols() == n, "ls_solve: System-matrix is not square");
00219   it_assert_debug(n == B.rows(), "The number of rows in A must equal the length of B!");
00220 
00221   ivec ipiv(n);
00222   X = B;
00223   mat LU = A;
00224 
00225   dgesv_(&n, &nrhs, LU._data(), &lda, ipiv._data(), X._data(), &ldb, &info);
00226 
00227   return (info == 0);
00228 }
00229 
00230 bool ls_solve(const cmat &A, const cvec &b, cvec &x)
00231 {
00232   int n, lda, ldb, nrhs, info;
00233   n = lda = ldb = A.rows();
00234   nrhs = 1;
00235 
00236   it_assert_debug(A.cols() == n, "ls_solve: System-matrix is not square");
00237   it_assert_debug(n == b.size(), "The number of rows in A must equal the length of b!");
00238 
00239   ivec ipiv(n);
00240   x = b;
00241   cmat LU = A;
00242 
00243   zgesv_(&n, &nrhs, LU._data(), &lda, ipiv._data(), x._data(), &ldb, &info);
00244 
00245   return (info == 0);
00246 }
00247 
00248 bool ls_solve(const cmat &A, const cmat &B, cmat &X)
00249 {
00250   int n, lda, ldb, nrhs, info;
00251   n = lda = ldb = A.rows();
00252   nrhs = B.cols();
00253 
00254   it_assert_debug(A.cols() == n, "ls_solve: System-matrix is not square");
00255   it_assert_debug(n == B.rows(), "The number of rows in A must equal the length of B!");
00256 
00257   ivec ipiv(n);
00258   X = B;
00259   cmat LU = A;
00260 
00261   zgesv_(&n, &nrhs, LU._data(), &lda, ipiv._data(), X._data(), &ldb, &info);
00262 
00263   return (info == 0);
00264 }
00265 
00266 #else
00267 
00268 bool ls_solve(const mat &A, const vec &b, vec &x)
00269 {
00270   it_error("LAPACK library is needed to use ls_solve() function");
00271   return false;
00272 }
00273 
00274 bool ls_solve(const mat &A, const mat &B, mat &X)
00275 {
00276   it_error("LAPACK library is needed to use ls_solve() function");
00277   return false;
00278 }
00279 
00280 bool ls_solve(const cmat &A, const cvec &b, cvec &x)
00281 {
00282   it_error("LAPACK library is needed to use ls_solve() function");
00283   return false;
00284 }
00285 
00286 bool ls_solve(const cmat &A, const cmat &B, cmat &X)
00287 {
00288   it_error("LAPACK library is needed to use ls_solve() function");
00289   return false;
00290 }
00291 
00292 #endif // HAVE_LAPACK
00293 
00294 vec ls_solve(const mat &A, const vec &b)
00295 {
00296   vec x;
00297   bool info;
00298   info = ls_solve(A, b, x);
00299   it_assert_debug(info, "ls_solve: Failed solving the system");
00300   return x;
00301 }
00302 
00303 mat ls_solve(const mat &A, const mat &B)
00304 {
00305   mat X;
00306   bool info;
00307   info = ls_solve(A, B, X);
00308   it_assert_debug(info, "ls_solve: Failed solving the system");
00309   return X;
00310 }
00311 
00312 cvec ls_solve(const cmat &A, const cvec &b)
00313 {
00314   cvec x;
00315   bool info;
00316   info = ls_solve(A, b, x);
00317   it_assert_debug(info, "ls_solve: Failed solving the system");
00318   return x;
00319 }
00320 
00321 cmat ls_solve(const cmat &A, const cmat &B)
00322 {
00323   cmat X;
00324   bool info;
00325   info = ls_solve(A, B, X);
00326   it_assert_debug(info, "ls_solve: Failed solving the system");
00327   return X;
00328 }
00329 
00330 
00331 // ----------------- ls_solve_od ------------------------------------------------------------------
00332 #if defined(HAVE_LAPACK)
00333 
00334 bool ls_solve_od(const mat &A, const vec &b, vec &x)
00335 {
00336   int m, n, lda, ldb, nrhs, lwork, info;
00337   char trans = 'N';
00338   m = lda = ldb = A.rows();
00339   n = A.cols();
00340   nrhs = 1;
00341   lwork = n + std::max(m, nrhs);
00342 
00343   it_assert_debug(m >= n, "The system is under-determined!");
00344   it_assert_debug(m == b.size(), "The number of rows in A must equal the length of b!");
00345 
00346   vec work(lwork);
00347   x = b;
00348   mat QR = A;
00349 
00350   dgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, x._data(), &ldb, work._data(), &lwork, &info);
00351   x.set_size(n, true);
00352 
00353   return (info == 0);
00354 }
00355 
00356 bool ls_solve_od(const mat &A, const mat &B, mat &X)
00357 {
00358   int m, n, lda, ldb, nrhs, lwork, info;
00359   char trans = 'N';
00360   m = lda = ldb = A.rows();
00361   n = A.cols();
00362   nrhs = B.cols();
00363   lwork = n + std::max(m, nrhs);
00364 
00365   it_assert_debug(m >= n, "The system is under-determined!");
00366   it_assert_debug(m == B.rows(), "The number of rows in A must equal the length of b!");
00367 
00368   vec work(lwork);
00369   X = B;
00370   mat QR = A;
00371 
00372   dgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, X._data(), &ldb, work._data(), &lwork, &info);
00373   X.set_size(n, nrhs, true);
00374 
00375   return (info == 0);
00376 }
00377 
00378 bool ls_solve_od(const cmat &A, const cvec &b, cvec &x)
00379 {
00380   int m, n, lda, ldb, nrhs, lwork, info;
00381   char trans = 'N';
00382   m = lda = ldb = A.rows();
00383   n = A.cols();
00384   nrhs = 1;
00385   lwork = n + std::max(m, nrhs);
00386 
00387   it_assert_debug(m >= n, "The system is under-determined!");
00388   it_assert_debug(m == b.size(), "The number of rows in A must equal the length of b!");
00389 
00390   cvec work(lwork);
00391   x = b;
00392   cmat QR = A;
00393 
00394   zgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, x._data(), &ldb, work._data(), &lwork, &info);
00395   x.set_size(n, true);
00396 
00397   return (info == 0);
00398 }
00399 
00400 bool ls_solve_od(const cmat &A, const cmat &B, cmat &X)
00401 {
00402   int m, n, lda, ldb, nrhs, lwork, info;
00403   char trans = 'N';
00404   m = lda = ldb = A.rows();
00405   n = A.cols();
00406   nrhs = B.cols();
00407   lwork = n + std::max(m, nrhs);
00408 
00409   it_assert_debug(m >= n, "The system is under-determined!");
00410   it_assert_debug(m == B.rows(), "The number of rows in A must equal the length of b!");
00411 
00412   cvec work(lwork);
00413   X = B;
00414   cmat QR = A;
00415 
00416   zgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, X._data(), &ldb, work._data(), &lwork, &info);
00417   X.set_size(n, nrhs, true);
00418 
00419   return (info == 0);
00420 }
00421 
00422 #else
00423 
00424 bool ls_solve_od(const mat &A, const vec &b, vec &x)
00425 {
00426   it_error("LAPACK library is needed to use ls_solve_od() function");
00427   return false;
00428 }
00429 
00430 bool ls_solve_od(const mat &A, const mat &B, mat &X)
00431 {
00432   it_error("LAPACK library is needed to use ls_solve_od() function");
00433   return false;
00434 }
00435 
00436 bool ls_solve_od(const cmat &A, const cvec &b, cvec &x)
00437 {
00438   it_error("LAPACK library is needed to use ls_solve_od() function");
00439   return false;
00440 }
00441 
00442 bool ls_solve_od(const cmat &A, const cmat &B, cmat &X)
00443 {
00444   it_error("LAPACK library is needed to use ls_solve_od() function");
00445   return false;
00446 }
00447 
00448 #endif // HAVE_LAPACK
00449 
00450 vec ls_solve_od(const mat &A, const vec &b)
00451 {
00452   vec x;
00453   bool info;
00454   info = ls_solve_od(A, b, x);
00455   it_assert_debug(info, "ls_solve_od: Failed solving the system");
00456   return x;
00457 }
00458 
00459 mat ls_solve_od(const mat &A, const mat &B)
00460 {
00461   mat X;
00462   bool info;
00463   info = ls_solve_od(A, B, X);
00464   it_assert_debug(info, "ls_solve_od: Failed solving the system");
00465   return X;
00466 }
00467 
00468 cvec ls_solve_od(const cmat &A, const cvec &b)
00469 {
00470   cvec x;
00471   bool info;
00472   info = ls_solve_od(A, b, x);
00473   it_assert_debug(info, "ls_solve_od: Failed solving the system");
00474   return x;
00475 }
00476 
00477 cmat ls_solve_od(const cmat &A, const cmat &B)
00478 {
00479   cmat X;
00480   bool info;
00481   info = ls_solve_od(A, B, X);
00482   it_assert_debug(info, "ls_solve_od: Failed solving the system");
00483   return X;
00484 }
00485 
00486 // ------------------- ls_solve_ud -----------------------------------------------------------
00487 #if defined(HAVE_LAPACK)
00488 
00489 bool ls_solve_ud(const mat &A, const vec &b, vec &x)
00490 {
00491   int m, n, lda, ldb, nrhs, lwork, info;
00492   char trans = 'N';
00493   m = lda = A.rows();
00494   n = A.cols();
00495   ldb = n;
00496   nrhs = 1;
00497   lwork = m + std::max(n, nrhs);
00498 
00499   it_assert_debug(m < n, "The system is over-determined!");
00500   it_assert_debug(m == b.size(), "The number of rows in A must equal the length of b!");
00501 
00502   vec work(lwork);
00503   x = b;
00504   x.set_size(n, true);
00505   mat QR = A;
00506 
00507   dgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, x._data(), &ldb, work._data(), &lwork, &info);
00508 
00509   return (info == 0);
00510 }
00511 
00512 bool ls_solve_ud(const mat &A, const mat &B, mat &X)
00513 {
00514   int m, n, lda, ldb, nrhs, lwork, info;
00515   char trans = 'N';
00516   m = lda = A.rows();
00517   n = A.cols();
00518   ldb = n;
00519   nrhs = B.cols();
00520   lwork = m + std::max(n, nrhs);
00521 
00522   it_assert_debug(m < n, "The system is over-determined!");
00523   it_assert_debug(m == B.rows(), "The number of rows in A must equal the length of b!");
00524 
00525   vec work(lwork);
00526   X = B;
00527   X.set_size(n, std::max(m, nrhs), true);
00528   mat QR = A;
00529 
00530   dgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, X._data(), &ldb, work._data(), &lwork, &info);
00531   X.set_size(n, nrhs, true);
00532 
00533   return (info == 0);
00534 }
00535 
00536 bool ls_solve_ud(const cmat &A, const cvec &b, cvec &x)
00537 {
00538   int m, n, lda, ldb, nrhs, lwork, info;
00539   char trans = 'N';
00540   m = lda = A.rows();
00541   n = A.cols();
00542   ldb = n;
00543   nrhs = 1;
00544   lwork = m + std::max(n, nrhs);
00545 
00546   it_assert_debug(m < n, "The system is over-determined!");
00547   it_assert_debug(m == b.size(), "The number of rows in A must equal the length of b!");
00548 
00549   cvec work(lwork);
00550   x = b;
00551   x.set_size(n, true);
00552   cmat QR = A;
00553 
00554   zgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, x._data(), &ldb, work._data(), &lwork, &info);
00555 
00556   return (info == 0);
00557 }
00558 
00559 bool ls_solve_ud(const cmat &A, const cmat &B, cmat &X)
00560 {
00561   int m, n, lda, ldb, nrhs, lwork, info;
00562   char trans = 'N';
00563   m = lda = A.rows();
00564   n = A.cols();
00565   ldb = n;
00566   nrhs = B.cols();
00567   lwork = m + std::max(n, nrhs);
00568 
00569   it_assert_debug(m < n, "The system is over-determined!");
00570   it_assert_debug(m == B.rows(), "The number of rows in A must equal the length of b!");
00571 
00572   cvec work(lwork);
00573   X = B;
00574   X.set_size(n, std::max(m, nrhs), true);
00575   cmat QR = A;
00576 
00577   zgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, X._data(), &ldb, work._data(), &lwork, &info);
00578   X.set_size(n, nrhs, true);
00579 
00580   return (info == 0);
00581 }
00582 
00583 #else
00584 
00585 bool ls_solve_ud(const mat &A, const vec &b, vec &x)
00586 {
00587   it_error("LAPACK library is needed to use ls_solve_ud() function");
00588   return false;
00589 }
00590 
00591 bool ls_solve_ud(const mat &A, const mat &B, mat &X)
00592 {
00593   it_error("LAPACK library is needed to use ls_solve_ud() function");
00594   return false;
00595 }
00596 
00597 bool ls_solve_ud(const cmat &A, const cvec &b, cvec &x)
00598 {
00599   it_error("LAPACK library is needed to use ls_solve_ud() function");
00600   return false;
00601 }
00602 
00603 bool ls_solve_ud(const cmat &A, const cmat &B, cmat &X)
00604 {
00605   it_error("LAPACK library is needed to use ls_solve_ud() function");
00606   return false;
00607 }
00608 
00609 #endif // HAVE_LAPACK
00610 
00611 
00612 vec ls_solve_ud(const mat &A, const vec &b)
00613 {
00614   vec x;
00615   bool info;
00616   info = ls_solve_ud(A, b, x);
00617   it_assert_debug(info, "ls_solve_ud: Failed solving the system");
00618   return x;
00619 }
00620 
00621 mat ls_solve_ud(const mat &A, const mat &B)
00622 {
00623   mat X;
00624   bool info;
00625   info = ls_solve_ud(A, B, X);
00626   it_assert_debug(info, "ls_solve_ud: Failed solving the system");
00627   return X;
00628 }
00629 
00630 cvec ls_solve_ud(const cmat &A, const cvec &b)
00631 {
00632   cvec x;
00633   bool info;
00634   info = ls_solve_ud(A, b, x);
00635   it_assert_debug(info, "ls_solve_ud: Failed solving the system");
00636   return x;
00637 }
00638 
00639 cmat ls_solve_ud(const cmat &A, const cmat &B)
00640 {
00641   cmat X;
00642   bool info;
00643   info = ls_solve_ud(A, B, X);
00644   it_assert_debug(info, "ls_solve_ud: Failed solving the system");
00645   return X;
00646 }
00647 
00648 
00649 // ---------------------- backslash -----------------------------------------
00650 
00651 bool backslash(const mat &A, const vec &b, vec &x)
00652 {
00653   int m = A.rows(), n = A.cols();
00654   bool info;
00655 
00656   if (m == n)
00657     info = ls_solve(A, b, x);
00658   else if (m > n)
00659     info = ls_solve_od(A, b, x);
00660   else
00661     info = ls_solve_ud(A, b, x);
00662 
00663   return info;
00664 }
00665 
00666 
00667 vec backslash(const mat &A, const vec &b)
00668 {
00669   vec x;
00670   bool info;
00671   info = backslash(A, b, x);
00672   it_assert_debug(info, "backslash(): solution was not found");
00673   return x;
00674 }
00675 
00676 
00677 bool backslash(const mat &A, const mat &B, mat &X)
00678 {
00679   int m = A.rows(), n = A.cols();
00680   bool info;
00681 
00682   if (m == n)
00683     info = ls_solve(A, B, X);
00684   else if (m > n)
00685     info = ls_solve_od(A, B, X);
00686   else
00687     info = ls_solve_ud(A, B, X);
00688 
00689   return info;
00690 }
00691 
00692 
00693 mat backslash(const mat &A, const mat &B)
00694 {
00695   mat X;
00696   bool info;
00697   info = backslash(A, B, X);
00698   it_assert_debug(info, "backslash(): solution was not found");
00699   return X;
00700 }
00701 
00702 
00703 bool backslash(const cmat &A, const cvec &b, cvec &x)
00704 {
00705   int m = A.rows(), n = A.cols();
00706   bool info;
00707 
00708   if (m == n)
00709     info = ls_solve(A, b, x);
00710   else if (m > n)
00711     info = ls_solve_od(A, b, x);
00712   else
00713     info = ls_solve_ud(A, b, x);
00714 
00715   return info;
00716 }
00717 
00718 
00719 cvec backslash(const cmat &A, const cvec &b)
00720 {
00721   cvec x;
00722   bool info;
00723   info = backslash(A, b, x);
00724   it_assert_debug(info, "backslash(): solution was not found");
00725   return x;
00726 }
00727 
00728 
00729 bool backslash(const cmat &A, const cmat &B, cmat &X)
00730 {
00731   int m = A.rows(), n = A.cols();
00732   bool info;
00733 
00734   if (m == n)
00735     info = ls_solve(A, B, X);
00736   else if (m > n)
00737     info = ls_solve_od(A, B, X);
00738   else
00739     info = ls_solve_ud(A, B, X);
00740 
00741   return info;
00742 }
00743 
00744 cmat backslash(const cmat &A, const cmat &B)
00745 {
00746   cmat X;
00747   bool info;
00748   info = backslash(A, B, X);
00749   it_assert_debug(info, "backslash(): solution was not found");
00750   return X;
00751 }
00752 
00753 
00754 // --------------------------------------------------------------------------
00755 
00756 vec forward_substitution(const mat &L, const vec &b)
00757 {
00758   int n = L.rows();
00759   vec x(n);
00760 
00761   forward_substitution(L, b, x);
00762 
00763   return x;
00764 }
00765 
00766 void forward_substitution(const mat &L, const vec &b, vec &x)
00767 {
00768   it_assert(L.rows() == L.cols() && L.cols() == b.size() && b.size() == x.size(),
00769             "forward_substitution: dimension mismatch");
00770   int n = L.rows(), i, j;
00771   double temp;
00772 
00773   x(0) = b(0) / L(0, 0);
00774   for (i = 1;i < n;i++) {
00775     // Should be: x(i)=((b(i)-L(i,i,0,i-1)*x(0,i-1))/L(i,i))(0); but this is to slow.
00776     //i_pos=i*L._row_offset();
00777     temp = 0;
00778     for (j = 0; j < i; j++) {
00779       temp += L._elem(i, j) * x(j);
00780       //temp+=L._data()[i_pos+j]*x(j);
00781     }
00782     x(i) = (b(i) - temp) / L._elem(i, i);
00783     //x(i)=(b(i)-temp)/L._data()[i_pos+i];
00784   }
00785 }
00786 
00787 vec forward_substitution(const mat &L, int p, const vec &b)
00788 {
00789   int n = L.rows();
00790   vec x(n);
00791 
00792   forward_substitution(L, p, b, x);
00793 
00794   return x;
00795 }
00796 
00797 void forward_substitution(const mat &L, int p, const vec &b, vec &x)
00798 {
00799   it_assert(L.rows() == L.cols() && L.cols() == b.size() && b.size() == x.size() && p <= L.rows() / 2,
00800             "forward_substitution: dimension mismatch");
00801   int n = L.rows(), i, j;
00802 
00803   x = b;
00804 
00805   for (j = 0;j < n;j++) {
00806     x(j) /= L(j, j);
00807     for (i = j + 1;i < std::min(j + p + 1, n);i++) {
00808       x(i) -= L(i, j) * x(j);
00809     }
00810   }
00811 }
00812 
00813 vec backward_substitution(const mat &U, const vec &b)
00814 {
00815   vec x(U.rows());
00816   backward_substitution(U, b, x);
00817 
00818   return x;
00819 }
00820 
00821 void backward_substitution(const mat &U, const vec &b, vec &x)
00822 {
00823   it_assert(U.rows() == U.cols() && U.cols() == b.size() && b.size() == x.size(),
00824             "backward_substitution: dimension mismatch");
00825   int n = U.rows(), i, j;
00826   double temp;
00827 
00828   x(n - 1) = b(n - 1) / U(n - 1, n - 1);
00829   for (i = n - 2; i >= 0; i--) {
00830     // Should be: x(i)=((b(i)-U(i,i,i+1,n-1)*x(i+1,n-1))/U(i,i))(0); but this is too slow.
00831     temp = 0;
00832     //i_pos=i*U._row_offset();
00833     for (j = i + 1; j < n; j++) {
00834       temp += U._elem(i, j) * x(j);
00835       //temp+=U._data()[i_pos+j]*x(j);
00836     }
00837     x(i) = (b(i) - temp) / U._elem(i, i);
00838     //x(i)=(b(i)-temp)/U._data()[i_pos+i];
00839   }
00840 }
00841 
00842 vec backward_substitution(const mat &U, int q, const vec &b)
00843 {
00844   vec x(U.rows());
00845   backward_substitution(U, q, b, x);
00846 
00847   return x;
00848 }
00849 
00850 void backward_substitution(const mat &U, int q, const vec &b, vec &x)
00851 {
00852   it_assert(U.rows() == U.cols() && U.cols() == b.size() && b.size() == x.size() && q <= U.rows() / 2,
00853             "backward_substitution: dimension mismatch");
00854   int n = U.rows(), i, j;
00855 
00856   x = b;
00857 
00858   for (j = n - 1; j >= 0; j--) {
00859     x(j) /= U(j, j);
00860     for (i = std::max(0, j - q); i < j; i++) {
00861       x(i) -= U(i, j) * x(j);
00862     }
00863   }
00864 }
00865 
00866 } // namespace itpp
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
SourceForge Logo

Generated on Sat Jul 9 2011 15:21:29 for IT++ by Doxygen 1.7.4