|
32 | 32 |
|
33 | 33 | namespace rk4_solver |
34 | 34 | { |
35 | | -/* |
36 | | - * Computes the next Runge-Kutta 4th Order step. |
37 | | - * `ode_fun` can be parametrized using the time (row) index `i`. |
38 | | - * |
39 | | - * `step<OPT: X_DIM, T>(obj, ode_fun, t, x, h, i, OUT:x_next)` |
40 | | - * |
41 | | - * 1. `obj`: dynamics object (type `T`) |
42 | | - * 2. `ode_fun`: ode function, member of `obj` (type `T::*`) |
43 | | - * 3. `t`: time [s] |
44 | | - * 4. `x`: state |
45 | | - * 5. `h`: time step [s] |
46 | | - * 6. `i`: time index corresponding to `t` |
47 | | - * |
48 | | - * OUT: |
49 | | - * 7. `x_next`: next state |
50 | | - */ |
51 | | -template <size_t X_DIM, typename T> |
52 | | -void |
53 | | -step(T &obj, OdeFun_T<X_DIM, T> ode_fun, const Real_T t, const Real_T (&x)[X_DIM], const Real_T h, |
54 | | - const size_t i, Real_T (&x_next)[X_DIM]) |
| 35 | +template <size_t X_DIM, typename T> class Integrator |
55 | 36 | { |
56 | | - constexpr Real_T rk4_weight_0 = 1. / 6.; |
57 | | - constexpr Real_T rk4_weight_1 = 1. / 3.; |
58 | | -#ifdef DO_NOT_USE_HEAP |
59 | | - static Real_T k_0[X_DIM]; |
60 | | - static Real_T k_1[X_DIM]; |
61 | | - static Real_T k_2[X_DIM]; |
62 | | - static Real_T k_3[X_DIM]; |
63 | | - static Real_T x_temp[X_DIM]; |
64 | | -#else |
| 37 | + public: |
| 38 | + Integrator() |
| 39 | + { |
| 40 | + for (size_t i = 0; i < X_DIM; ++i) { |
| 41 | + accumulator[i] = 0; |
| 42 | + } |
| 43 | + } |
| 44 | + |
65 | 45 | /* |
66 | | - * `..._ptr`s are of type `Real_T(*)[X_DIM]`. |
67 | | - * They point to `Real_T[X_DIM]`s which are allocated on the heap. |
68 | | - * Dereferencing them gives us rvalue references to `Real_T[X_DIM]`s, |
69 | | - * which can be substituted for `Real_T[X_DIM]`s allocated on the stack. |
70 | | - * (Maybe typedef should be used more.) |
| 46 | + * Computes the next Runge-Kutta 4th Order step. |
| 47 | + * `ode_fun` can be parametrized using the time (row) index `i`. |
| 48 | + * |
| 49 | + * `step<OPT: X_DIM, T>(obj, ode_fun, t, x, h, i, OUT:x_next)` |
| 50 | + * |
| 51 | + * 1. `obj`: dynamics object (type `T`) |
| 52 | + * 2. `ode_fun`: ode function, member of `obj` (type `T::*`) |
| 53 | + * 3. `t`: time [s] |
| 54 | + * 4. `x`: state |
| 55 | + * 5. `h`: time step [s] |
| 56 | + * 6. `i`: time index corresponding to `t` |
| 57 | + * |
| 58 | + * OUT: |
| 59 | + * 7. `x_next`: next state |
71 | 60 | */ |
72 | | - static Real_T(*k_0_ptr)[X_DIM] = (Real_T(*)[X_DIM]) new Real_T[X_DIM]; |
73 | | - static Real_T(*k_1_ptr)[X_DIM] = (Real_T(*)[X_DIM]) new Real_T[X_DIM]; |
74 | | - static Real_T(*k_2_ptr)[X_DIM] = (Real_T(*)[X_DIM]) new Real_T[X_DIM]; |
75 | | - static Real_T(*k_3_ptr)[X_DIM] = (Real_T(*)[X_DIM]) new Real_T[X_DIM]; |
76 | | - static Real_T(*x_temp_ptr)[X_DIM] = (Real_T(*)[X_DIM]) new Real_T[X_DIM]; |
77 | | - Real_T(*dx_ptr)[X_DIM] = (Real_T(*)[X_DIM]) new Real_T[X_DIM]{}; //* not static, needs to be zeroed before every loop |
78 | | - Real_T(&k_0)[X_DIM] = *k_0_ptr; |
79 | | - Real_T(&k_1)[X_DIM] = *k_1_ptr; |
80 | | - Real_T(&k_2)[X_DIM] = *k_2_ptr; |
81 | | - Real_T(&k_3)[X_DIM] = *k_3_ptr; |
82 | | - Real_T(&x_temp)[X_DIM] = *x_temp_ptr; |
83 | | - Real_T(&dx)[X_DIM] = *dx_ptr; |
84 | | -#endif |
| 61 | + void |
| 62 | + step(T &obj, OdeFun_T<X_DIM, T> ode_fun, const Real_T t, const Real_T (&x)[X_DIM], |
| 63 | + const Real_T h, const size_t i, Real_T (&x_next)[X_DIM]) |
| 64 | + { |
| 65 | + (obj.*ode_fun)(t, x, i, k_0); //* ode_fun(ti, xi) |
85 | 66 |
|
86 | | - (obj.*ode_fun)(t, x, i, k_0); //* ode_fun(ti, xi) |
| 67 | + //* zero-order hold, i.e. no ODE_FUN(,, i+.5), ODE_FUN(,, i+1,) etc. |
| 68 | + matrix_op::weighted_sum(h / 2, k_0, 1., x, x_temp); |
| 69 | + (obj.*ode_fun)(t + h / 2, x_temp, i, k_1); //* ode_fun(ti + h/2, xi + h/2*k_0) |
87 | 70 |
|
88 | | - //* zero-order hold, i.e. no ODE_FUN(,, i+.5), ODE_FUN(,, i+1,) etc. |
89 | | - matrix_op::weighted_sum(h / 2, k_0, 1., x, x_temp); |
90 | | - (obj.*ode_fun)(t + h / 2, x_temp, i, k_1); //* ode_fun(ti + h/2, xi + h/2*k_0) |
| 71 | + matrix_op::weighted_sum(h / 2, k_1, 1., x, x_temp); |
| 72 | + (obj.*ode_fun)(t + h / 2, x_temp, i, k_2); //* ode_fun(ti + h/2, xi + h/2*k_1) |
91 | 73 |
|
92 | | - matrix_op::weighted_sum(h / 2, k_1, 1., x, x_temp); |
93 | | - (obj.*ode_fun)(t + h / 2, x_temp, i, k_2); //* ode_fun(ti + h/2, xi + h/2*k_1) |
| 74 | + matrix_op::weighted_sum(h, k_2, 1., x, x_temp); |
| 75 | + (obj.*ode_fun)(t + h, x_temp, i, k_3); //* ode_fun(ti + h, xi + k_2) |
94 | 76 |
|
95 | | - matrix_op::weighted_sum(h, k_2, 1., x, x_temp); |
96 | | - (obj.*ode_fun)(t + h, x_temp, i, k_3); //* ode_fun(ti + h, xi + k_2) |
| 77 | + constexpr Real_T w0 = 1. / 6.; |
| 78 | + constexpr Real_T w1 = 1. / 3.; |
97 | 79 |
|
98 | | - //* compensated summation (Kahan summation), probably ffast-math would break it |
99 | | - for (size_t i = 0; i < X_DIM; ++i) { |
100 | | - dx[i] += h * |
101 | | - (rk4_weight_0 * k_0[i] + rk4_weight_1 * k_1[i] + rk4_weight_1 * k_2[i] + |
102 | | - rk4_weight_0 * k_3[i]); //* dx accumulates floating point errors |
103 | | - x_temp[i] = x[i]; //* stores x when x_next is pointing to x's address |
104 | | - x_next[i] = x[i] + dx[i]; //* uncompensated summation |
105 | | - dx[i] -= (x_next[i] - x_temp[i]); //* removes the uncompensated summation from the accumulated error |
| 80 | + for (size_t i = 0; i < X_DIM; ++i) { |
| 81 | + dx = h * (w0 * k_0[i] + w1 * k_1[i] + w1 * k_2[i] + w0 * k_3[i]); |
| 82 | + //* compensated (Kahan) summation, ffast-math might break this |
| 83 | + compensated_dx = dx - accumulator[i]; |
| 84 | + x_temp[i] = x[i] + compensated_dx; |
| 85 | + accumulator[i] = (x_temp[i] - x[i]) - compensated_dx; |
| 86 | + x_next[i] = x_temp[i]; |
| 87 | + } |
106 | 88 | } |
107 | | -} |
| 89 | + |
| 90 | + private: |
| 91 | + Real_T dx; |
| 92 | + Real_T compensated_dx; |
| 93 | + |
| 94 | +#ifdef DO_NOT_USE_HEAP |
| 95 | + Real_T k_0[X_DIM]; |
| 96 | + Real_T k_1[X_DIM]; |
| 97 | + Real_T k_2[X_DIM]; |
| 98 | + Real_T k_3[X_DIM]; |
| 99 | + Real_T x_temp[X_DIM]; |
| 100 | + Real_T accumulator[X_DIM]; |
| 101 | +#else |
| 102 | + /* |
| 103 | + * Dereferencing pointers that point to `Real_T[X_DIM]`s which are allocated on the heap, in |
| 104 | + * order to get rvalue references to the `Real_T[X_DIM]`s. |
| 105 | + */ |
| 106 | + Real_T (&k_0)[X_DIM] = *(Real_T(*)[X_DIM]) new Real_T[X_DIM]; |
| 107 | + Real_T (&k_1)[X_DIM] = *(Real_T(*)[X_DIM]) new Real_T[X_DIM]; |
| 108 | + Real_T (&k_2)[X_DIM] = *(Real_T(*)[X_DIM]) new Real_T[X_DIM]; |
| 109 | + Real_T (&k_3)[X_DIM] = *(Real_T(*)[X_DIM]) new Real_T[X_DIM]; |
| 110 | + Real_T (&x_temp)[X_DIM] = *(Real_T(*)[X_DIM]) new Real_T[X_DIM]; |
| 111 | + Real_T (&accumulator)[X_DIM] = *(Real_T(*)[X_DIM]) new Real_T[X_DIM]; |
| 112 | +#endif |
| 113 | +}; |
108 | 114 | } // namespace rk4_solver |
109 | 115 |
|
110 | 116 | #endif |
0 commit comments