sablib
Loading...
Searching...
No Matches
cubic_spline.h
Go to the documentation of this file.
1
6
7#ifndef __SABLIB_CUBIC_SPLINE_H__
8#define __SABLIB_CUBIC_SPLINE_H__
9
10#include <stdexcept>
11#include <type_traits>
12#include <vector>
13#include <Eigen/Eigen>
14
15namespace sablib {
16
22template <typename Scalar>
23class CubicSpline final
24{
25 static_assert(std::is_floating_point_v<Scalar>, "Scalar must be a floating-point type.");
26
27public:
31 CubicSpline() = default;
32
39 CubicSpline(const Eigen::VectorX<Scalar> & x, const Eigen::VectorX<Scalar> & y);
40
48 void Fit(const Eigen::VectorX<Scalar> & x, const Eigen::VectorX<Scalar> & y);
49
56 Scalar Interpolate(const double x) const;
57
64 Scalar operator()(const double x) const
65 {
66 return Interpolate(x);
67 }
68
69private:
70 Eigen::VectorX<Scalar> a, b, c, d;
71 Eigen::VectorX<Scalar> sp_x, sp_y;
72
82 Eigen::VectorX<Scalar> SolveTridiagonal(
83 const Eigen::VectorX<Scalar>& lower, const Eigen::VectorX<Scalar>& diag,
84 const Eigen::VectorX<Scalar>& upper, const Eigen::VectorX<Scalar>& rhs
85 );
86};
87
88//
89// Implementation of constructor
90//
91template <typename Scalar>
92inline CubicSpline<Scalar>::CubicSpline(const Eigen::VectorX<Scalar>& x, const Eigen::VectorX<Scalar>& y)
93{
94 Fit(x, y);
95}
96
97//
98// Implementation of Fit() method
99//
100template <typename Scalar>
101void CubicSpline<Scalar>::Fit(const Eigen::VectorX<Scalar> & x, const Eigen::VectorX<Scalar> & y)
102{
103 if (x.size() != y.size()) {
104 throw std::invalid_argument("CubicSpline::Fit(): x and y must have the same size.");
105 }
106
107 if (x.size() < 2) {
108 throw std::invalid_argument("CubicSpline::Fit(): At least 2 points are required.");
109 }
110
111 sp_x = x;
112 sp_y = y;
113
114 int n = x.size() - 1;
115 Eigen::VectorX<Scalar> h(n);
116
117 for(int i = 0; i < n; i++) {
118 h(i) = x(i + 1) - x(i);
119 if (h(i) <= 0) {
120 throw std::invalid_argument("CubicSpline::Fit(): x must be strictly increasing.");
121 }
122 }
123
124 Eigen::VectorX<Scalar> lower = Eigen::VectorX<Scalar>::Zero(n + 1);
125 Eigen::VectorX<Scalar> diag = Eigen::VectorX<Scalar>::Ones(n + 1);
126 Eigen::VectorX<Scalar> upper = Eigen::VectorX<Scalar>::Zero(n + 1);
127 Eigen::VectorX<Scalar> rhs = Eigen::VectorX<Scalar>::Zero(n + 1);
128
129 for(int i = 1; i < n; i++) {
130 lower(i) = h(i - 1);
131 diag(i) = 2 * (h(i - 1) + h(i));
132 upper(i) = h(i);
133 rhs(i) = 6 * ((y(i + 1) - y(i)) / h(i) - (y(i) - y(i - 1)) / h(i - 1));
134 }
135
136 Eigen::VectorX<Scalar> m = SolveTridiagonal(lower, diag, upper, rhs);
137
138 a.resize(n);
139 b.resize(n);
140 c.resize(n);
141 d.resize(n);
142
143 for(int i=0;i<n;i++) {
144 a(i) = y(i);
145 b(i) = (y(i + 1) - y(i)) / h(i) - h(i) * (2 * m(i) + m(i + 1)) / 6.0;
146 c(i) = m(i) / 2.0;
147 d(i) = (m(i + 1) - m(i)) / (6.0 * h(i));
148 }
149}
150
151//
152// Implementation of Thomas' algorithm (TDMA)
153//
154template <typename Scalar>
155Eigen::VectorX<Scalar> CubicSpline<Scalar>::SolveTridiagonal(
156 const Eigen::VectorX<Scalar>& lower, const Eigen::VectorX<Scalar>& diag,
157 const Eigen::VectorX<Scalar>& upper, const Eigen::VectorX<Scalar>& rhs
158)
159{
160 int n = diag.size();
161 Eigen::VectorX<Scalar> cp = Eigen::VectorX<Scalar>::Zero(n);
162 Eigen::VectorX<Scalar> dp = Eigen::VectorX<Scalar>::Zero(n);
163 Eigen::VectorX<Scalar> x = Eigen::VectorX<Scalar>::Zero(n);
164
165 // Forward elimination
166 cp(0) = upper(0) / diag(0);
167 dp(0) = rhs(0) / diag(0);
168
169 for (int i = 1; i < n; i++) {
170 Scalar m = diag(i) - lower(i) * cp(i - 1);
171 cp(i) = upper(i) / m;
172 dp(i) = (rhs(i) - lower(i) * dp(i - 1)) / m;
173 }
174
175 // Backward substitution
176 x(n - 1) = dp(n - 1);
177 for (int i = n - 2; i >= 0; i--) {
178 x(i) = dp(i) - cp(i) * x(i + 1);
179 }
180
181 return x;
182}
183
184//
185// Implementation of Interpolate() method
186//
187template <typename Scalar>
188Scalar CubicSpline<Scalar>::Interpolate(const double x) const
189{
190 int n = sp_x.size() - 1;
191 int i = n - 1;
192
193 for(int j = 0; j < n; j++) {
194 if(x >= sp_x(j) && x <= sp_x(j + 1)) {
195 i = j;
196 break;
197 }
198 }
199
200 double dx = x - sp_x(i);
201
202 return a(i) + b(i) * dx + c(i) * dx * dx + d(i) * dx * dx * dx;
203}
204
205}; // namespace sablib
206
207#endif // __SABLIB_CUBIC_SPLINE_H__
void Fit(const Eigen::VectorX< Scalar > &x, const Eigen::VectorX< Scalar > &y)
Fits a cubic spline to the given data points.
CubicSpline()=default
Default constructor.
Scalar operator()(const double x) const
Interpolates the value at a given x-coordinate.
Scalar Interpolate(const double x) const
Interpolates the value at a given x-coordinate.