Loading...
Searching...
No Matches
CG-Impl.hpp
Go to the documentation of this file.
1// This code is based on Jet framework.
2// Copyright (c) 2018 Doyub Kim
3// CubbyFlow is voxel-based fluid simulation engine for computer games.
4// Copyright (c) 2020 CubbyFlow Team
5// Core Part: Chris Ohk, Junwoo Hwang, Jihong Sin, Seungwoo Yoo
6// AI Part: Dongheon Cho, Minseo Kim
7// We are making our contributions/submissions to this project solely in our
8// personal capacity and are not conveying any rights to any intellectual
9// property of any third parties.
10
11#ifndef CUBBYFLOW_CG_IMPL_HPP
12#define CUBBYFLOW_CG_IMPL_HPP
13
15
16namespace CubbyFlow
17{
18template <typename BLASType>
19void CG(const typename BLASType::MatrixType& A,
20 const typename BLASType::VectorType& b,
21 unsigned int maxNumberOfIterations, double tolerance,
22 typename BLASType::VectorType* x, typename BLASType::VectorType* r,
23 typename BLASType::VectorType* d, typename BLASType::VectorType* q,
24 typename BLASType::VectorType* s, unsigned int* lastNumberOfIterations,
25 double* lastResidualNorm)
26{
29
31 x, r, d, q, s, lastNumberOfIterations,
32 lastResidualNorm);
33}
34
//! Solves A x = b with the preconditioned conjugate gradient method.
//!
//! \param A                      System matrix.
//! \param b                      Right-hand side.
//! \param maxNumberOfIterations  Upper bound on the number of CG iterations.
//! \param tolerance              Iteration stops once r.(M^-1 r) <= tolerance^2.
//! \param M                      Preconditioner; M->Solve(in, out) computes
//!                               out = M^-1 in.
//! \param x                      In: initial guess. Out: the final iterate.
//! \param r, d, q, s             Work vectors (residual, search direction,
//!                               A*d, preconditioned residual). They are
//!                               cleared on entry.
//! \param lastNumberOfIterations Out: number of iterations performed.
//! \param lastResidualNorm       Out: sqrt(|r.(M^-1 r)|) of the final residual.
//!                               This equals ||r||_2 only when M is the
//!                               identity.
template <typename BLASType, typename PrecondType>
void PCG(const typename BLASType::MatrixType& A,
         const typename BLASType::VectorType& b,
         unsigned int maxNumberOfIterations, double tolerance, PrecondType* M,
         typename BLASType::VectorType* x, typename BLASType::VectorType* r,
         typename BLASType::VectorType* d, typename BLASType::VectorType* q,
         typename BLASType::VectorType* s, unsigned int* lastNumberOfIterations,
         double* lastResidualNorm)
{
    // Clear
    BLASType::Set(0, r);
    BLASType::Set(0, d);
    BLASType::Set(0, q);
    BLASType::Set(0, s);

    // r = b - Ax
    BLASType::Residual(A, *x, b, r);

    // d = M^-1r
    M->Solve(*r, d);

    // sigmaNew = r.d
    double sigmaNew = BLASType::Dot(*r, *d);

    // Stopping threshold; loop-invariant, so computed once.
    const double toleranceSquared = tolerance * tolerance;

    unsigned int iter = 0;

    // When set, the next iteration recomputes the residual from scratch
    // instead of using the cheap recursive update.
    bool trigger = false;

    while (sigmaNew > toleranceSquared && iter < maxNumberOfIterations)
    {
        // q = Ad
        BLASType::MVM(A, *d, q);

        // alpha = sigmaNew / d.q
        const double alpha = sigmaNew / BLASType::Dot(*d, *q);

        // x = x + alpha * d
        BLASType::AXPlusY(alpha, *d, *x, x);

        // Every 50 iterations, or after sigma grew, recompute the true
        // residual so that rounding drift in the recursive update does not
        // accumulate.
        if (trigger || (iter % 50 == 0 && iter > 0))
        {
            // r = b - Ax
            BLASType::Residual(A, *x, b, r);
            trigger = false;
        }
        else
        {
            // r = r - alpha * q
            BLASType::AXPlusY(-alpha, *q, *r, r);
        }

        // s = M^-1r
        M->Solve(*r, s);

        // sigmaOld = sigmaNew
        const double sigmaOld = sigmaNew;

        // sigmaNew = r.s
        sigmaNew = BLASType::Dot(*r, *s);

        if (sigmaNew > sigmaOld)
        {
            trigger = true;
        }

        // beta = sigmaNew / sigmaOld
        const double beta = sigmaNew / sigmaOld;

        // d = s + beta*d
        BLASType::AXPlusY(beta, *d, *s, d);

        ++iter;
    }

    *lastNumberOfIterations = iter;

    // std::fabs(sigmaNew) - Workaround for negative zero
    *lastResidualNorm = std::sqrt(std::fabs(sigmaNew));
}
114} // namespace CubbyFlow
115
116#endif
std::enable_if_t<(IsMatrixSizeDynamic< Rows, Cols >()||Cols==1) &&(IsMatrixSizeDynamic< R, C >()||C==1), U > Dot(const MatrixExpression< T, R, C, E > &expression) const
Definition MatrixExpression-Impl.hpp:391
Definition Matrix.hpp:30
Definition pybind11Utils.hpp:21
std::enable_if_t< std::is_arithmetic< T >::value, T > Square(T x)
Returns the square of x.
Definition MathUtils-Impl.hpp:154
Matrix< T, Rows, 1 > Vector
Definition Matrix.hpp:738
void PCG(const typename BLASType::MatrixType &A, const typename BLASType::VectorType &b, unsigned int maxNumberOfIterations, double tolerance, PrecondType *M, typename BLASType::VectorType *x, typename BLASType::VectorType *r, typename BLASType::VectorType *d, typename BLASType::VectorType *q, typename BLASType::VectorType *s, unsigned int *lastNumberOfIterations, double *lastResidualNorm)
Solves pre-conditioned conjugate gradient.
Definition CG-Impl.hpp:36
void CG(const typename BLASType::MatrixType &A, const typename BLASType::VectorType &b, unsigned int maxNumberOfIterations, double tolerance, typename BLASType::VectorType *x, typename BLASType::VectorType *r, typename BLASType::VectorType *d, typename BLASType::VectorType *q, typename BLASType::VectorType *s, unsigned int *lastNumberOfIterations, double *lastResidualNorm)
Solves conjugate gradient.
Definition CG-Impl.hpp:19