sablib
Loading...
Searching...
No Matches
spdiags.h
Go to the documentation of this file.
1
6
7#ifndef __SABLIB_SPDIAGS_H__
8#define __SABLIB_SPDIAGS_H__
9
10#include <algorithm>
11#include <initializer_list>
12#include <stdexcept>
13#include <vector>
14#include <Eigen/Eigen>
15
16namespace sablib {
17
28template <typename Derived>
29Eigen::SparseMatrix<typename Derived::PlainObject::Scalar>
30Spdiags(const Eigen::MatrixBase<Derived> & data, const Eigen::VectorXi & diags, const int m = -1, const int n = -1)
31{
32 using Scalar = typename Derived::PlainObject::Scalar;
33 using T = Eigen::Triplet<Scalar>;
34
35 if(diags.size() > data.rows()) {
36 throw std::invalid_argument("diags size is larger than rows of data.");
37 }
38
39 int row_size, column_size;
40 Eigen::SparseMatrix<Scalar> a;
41 std::vector<T> triplets;
42
43 row_size = (m <= 0) ? (int)data.rows() : m;
44 column_size = (n <= 0) ? (int)data.cols() : n;
45
46 a.resize(row_size, column_size);
47 triplets.resize(row_size * diags.size());
48
49 for (int k = 0; k < diags.size(); k++) {
50 int start_index = std::max(0, diags(k));
51 int end_index = std::min((int)data.cols(), column_size);
52
53 for(int i = start_index; i < end_index; i++) {
54 if(i - diags(k) < row_size && i < column_size) {
55 triplets.emplace_back(i - diags(k), i, data(k, i));
56 }
57 }
58 }
59
60 a.setFromTriplets(triplets.begin(), triplets.end());
61
62 return a;
63}
64
75template <typename Derived>
76inline Eigen::SparseMatrix<typename Derived::PlainObject::Scalar>
77Spdiags(const Eigen::MatrixBase<Derived> & data, const std::vector<int> & diags, const int m = -1, const int n = -1)
78{
79 return Spdiags(
80 data,
81 Eigen::Map<Eigen::VectorXi>((int *)diags.data(), diags.size()),
82 m, n
83 );
84}
85
96template <typename Derived>
97inline Eigen::SparseMatrix<typename Derived::PlainObject::Scalar>
98Spdiags(const Eigen::MatrixBase<Derived> & data, const std::initializer_list<int> & diags, const int m = -1, const int n = -1)
99{
100 return Spdiags(data, std::vector<int>{diags}, m, n);
101}
102
103}; // namespace sablib
104
105#endif // __SABLIB_SPDIAGS_H__
Eigen::SparseMatrix< typename Derived::PlainObject::Scalar > Spdiags(const Eigen::MatrixBase< Derived > &data, const Eigen::VectorXi &diags, const int m=-1, const int n=-1)
Returns a sparse matrix with the specified elements on its diagonals.
Definition spdiags.h:30