// sablib — bspline.h (B-spline fitting and evaluation)
7#ifndef __SABLIB_BSPLINE_H__
8#define __SABLIB_BSPLINE_H__
9
#include <algorithm>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <vector>
#include <Eigen/Eigen>
14
15namespace sablib {
16
22template <typename Scalar>
23class BSpline final
24{
25 static_assert(std::is_floating_point_v<Scalar>, "Scalar must be a floating-point type.");
26
27public:
28 BSpline() = delete;
29
36 BSpline(const int degree, const Eigen::VectorX<Scalar> & knots);
37
46 BSpline(
47 const int degree, const Eigen::VectorX<Scalar> & knots,
48 const Eigen::VectorX<Scalar> & x, const Eigen::VectorX<Scalar> & y
49 );
50
58 BSpline(
59 const int degree, const Eigen::VectorX<Scalar> & knots, const Eigen::VectorX<Scalar> & coefficients
60 );
61
67 int BasisSize() const;
68
74 const Eigen::VectorX<Scalar> Knots() const;
75
81 const Eigen::VectorX<Scalar> Coefficients() const;
82
89 const Eigen::SparseMatrix<Scalar> DesignMatrix(const Eigen::VectorX<Scalar> & x) const;
90
97 void Fit(const Eigen::VectorX<Scalar> & x, const Eigen::VectorX<Scalar> & y);
98
106 const Scalar Interpolate(const Scalar x, const Eigen::VectorX<Scalar> & coefficients) const;
107
114 const Scalar Interpolate(const Scalar x) const;
115
116private:
117 int sp_degree;
118 Eigen::VectorX<Scalar> sp_knots;
119 int basis;
120 Eigen::VectorX<Scalar> sp_coefficients;
121
128 int FindSpan(const Scalar x) const;
129
137 const Eigen::VectorX<Scalar> BasisFunctions(const int span, const Scalar x) const;
138};
139
140//
141// Implementation of constructor
142//
143template <typename Scalar>
144BSpline<Scalar>::BSpline(
145 const int degree, const Eigen::VectorX<Scalar> & knots
146) : sp_degree(degree), sp_knots(knots), basis(knots.size() - degree - 1)
147{
148}
149
150//
151// Implementation of constructor with fitting
152//
153template <typename Scalar>
154BSpline<Scalar>::BSpline(
155 const int degree, const Eigen::VectorX<Scalar> & knots,
156 const Eigen::VectorX<Scalar> & x, const Eigen::VectorX<Scalar> & y
157) : sp_degree(degree), sp_knots(knots), basis(knots.size() - degree - 1)
158{
159 Fit(x, y);
160}
161
162//
163// Implementation of constructor with coefficients
164//
165template <typename Scalar>
166BSpline<Scalar>::BSpline(
167 const int degree, const Eigen::VectorX<Scalar> & knots, const Eigen::VectorX<Scalar> & coefficients
168) : sp_degree(degree), sp_knots(knots), basis(knots.size() - degree - 1), sp_coefficients(coefficients)
169{
170}
171
172//
173// Implementation of BasisSize() method
174//
175template <typename Scalar>
177{
178 return basis;
179}
180
181//
182// Implementation of Knots() method
183//
184template <typename Scalar>
185const Eigen::VectorX<Scalar> BSpline<Scalar>::Knots() const
186{
187 return sp_knots;
188}
189
190//
191// Implementation of Coefficients() method
192//
193template <typename Scalar>
194const Eigen::VectorX<Scalar> BSpline<Scalar>::Coefficients() const
195{
196 return sp_coefficients;
197}
198
199//
200// Implementation of FindSpan() method
201//
202template <typename Scalar>
203int BSpline<Scalar>::FindSpan(const Scalar x) const
204{
205 int n = basis - 1;
206
207 if (x >= sp_knots(n + 1)) {
208 return n;
209 }
210
211 if (x <= sp_knots(sp_degree)) {
212 return sp_degree;
213 }
214
215 int low = sp_degree;
216 int high = n + 1;
217 int mid = (low + high) / 2;
218
219 while (x < sp_knots(mid) || x >= sp_knots(mid + 1)) {
220 if (x < sp_knots(mid)) {
221 high = mid;
222 }
223 else {
224 low = mid;
225 }
226
227 mid = (low + high) / 2;
228 }
229
230 return mid;
231}
232
233//
234// Implementation of BasisFunctions() method
235//
236template <typename Scalar>
237const Eigen::VectorX<Scalar> BSpline<Scalar>::BasisFunctions(const int span, const Scalar x) const
238{
239 Eigen::VectorX<Scalar> N(sp_degree + 1);
240 Eigen::VectorX<Scalar> left(sp_degree + 1);
241 Eigen::VectorX<Scalar> right(sp_degree + 1);
242
243 N(0) = 1.0;
244
245 for(int j = 1; j <= sp_degree; j++) {
246 left(j) = x - sp_knots(span + 1 - j);
247 right(j) = sp_knots(span + j) - x;
248
249 Scalar saved = 0.0;
250
251 for(int r = 0; r < j; r++) {
252 Scalar temp = N(r) / (right(r + 1) + left(j - r));
253 N(r) = saved + right(r + 1) * temp;
254 saved = left(j - r) * temp;
255 }
256
257 N(j) = saved;
258 }
259
260 return N;
261}
262
263//
264// Implementation of DesignMatrix() method
265//
266template <typename Scalar>
267const Eigen::SparseMatrix<Scalar> BSpline<Scalar>::DesignMatrix(const Eigen::VectorX<Scalar> & x) const
268{
269 int n = x.size();
270 std::vector< Eigen::Triplet<Scalar> > triplets;
271
272 for(int r = 0; r < n; r++) {
273 Scalar xi = x(r);
274 int span = FindSpan(xi);
275 Eigen::VectorX<Scalar> N = BasisFunctions(span, xi);
276
277 for(int j = 0; j <= sp_degree; j++) {
278 int col = span - sp_degree + j;
279
280 if(col >=0 && col < basis) {
281 triplets.emplace_back(r, col, N(j));
282 }
283 }
284 }
285
286 Eigen::SparseMatrix<Scalar> B(n , basis);
287 B.setFromTriplets(triplets.begin(), triplets.end());
288
289 return B;
290}
291
292//
293// Implementation of Fit() method
294//
295template <typename Scalar>
296void BSpline<Scalar>::Fit(const Eigen::VectorX<Scalar> & x, const Eigen::VectorX<Scalar> & y)
297{
298 auto B = DesignMatrix(x);
299 Eigen::SparseMatrix<Scalar> BTB = B.transpose() * B;
300 Eigen::SimplicialLDLT< Eigen::SparseMatrix<Scalar> > solver;
301
302 solver.compute(BTB);
303
304 if(solver.info() != Eigen::Success) {
305 throw std::runtime_error("BSpline::Fit(): solver calculation fails.");
306 }
307
308 sp_coefficients = solver.solve(B.transpose() * y);
309}
310
311//
312// Implementation of Interpolate() method
313//
314template<typename Scalar>
315const Scalar BSpline<Scalar>::Interpolate(const Scalar x, const Eigen::VectorX<Scalar> & coefficients) const
316{
317 int span = FindSpan(x);
318 Eigen::VectorX<Scalar> N = BasisFunctions(span, x);
319
320 Scalar y = 0.0;
321
322 for(int j = 0; j <= sp_degree; j++) {
323 int idx = span - sp_degree + j;
324
325 if(idx >=0 && idx < coefficients.size()) {
326 y += N(j) * coefficients(idx);
327 }
328 }
329
330 return y;
331}
332
333//
334// Implementation of Interpolate() method with internal coefficients
335//
336template<typename Scalar>
337inline const Scalar BSpline<Scalar>::Interpolate(const Scalar x) const
338{
339 return Interpolate(x, sp_coefficients);
340}
341
342} // namespace sablib
343
344#endif // __SABLIB_BSPLINE_H__
const Eigen::SparseMatrix< Scalar > DesignMatrix(const Eigen::VectorX< Scalar > &x) const
Constructs the design matrix (collocation matrix) for a given set of x-coordinates.
Definition bspline.h:267
int BasisSize() const
Gets the number of basis functions.
Definition bspline.h:176
void Fit(const Eigen::VectorX< Scalar > &x, const Eigen::VectorX< Scalar > &y)
Fits the B-Spline to the given data points by calculating the coefficients.
Definition bspline.h:296
const Eigen::VectorX< Scalar > Coefficients() const
Returns the internal B-Spline coefficients.
Definition bspline.h:194
const Scalar Interpolate(const Scalar x, const Eigen::VectorX< Scalar > &coefficients) const
Interpolates the value at a given x-coordinate using provided B-Spline coefficients.
Definition bspline.h:315
const Eigen::VectorX< Scalar > Knots() const
Returns the knot vector.
Definition bspline.h:185