PaStiXSupport.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2012 Désiré Nuentsa-Wakam <desire.nuentsa_wakam@inria.fr>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_PASTIXSUPPORT_H
11 #define EIGEN_PASTIXSUPPORT_H
12 
13 namespace Eigen {
14 
// Forward declarations of the PaStiX-based solvers: internal::pastix_traits and
// PastixBase refer to these names before the classes themselves are defined.
// The default for IsStrSym can only be given once, so it lives here and not on
// the PastixLU definition.
template<typename MatrixType_, bool IsStrSym = false> class PastixLU;
template<typename MatrixType_, int UpLo_> class PastixLLT;
template<typename MatrixType_, int UpLo_> class PastixLDLT;
26 
27 namespace internal
28 {
29 
30  template<class Pastix> struct pastix_traits;
31 
32  template<typename _MatrixType>
33  struct pastix_traits< PastixLU<_MatrixType> >
34  {
35  typedef _MatrixType MatrixType;
36  typedef typename _MatrixType::Scalar Scalar;
37  typedef typename _MatrixType::RealScalar RealScalar;
38  typedef typename _MatrixType::Index Index;
39  };
40 
41  template<typename _MatrixType, int Options>
42  struct pastix_traits< PastixLLT<_MatrixType,Options> >
43  {
44  typedef _MatrixType MatrixType;
45  typedef typename _MatrixType::Scalar Scalar;
46  typedef typename _MatrixType::RealScalar RealScalar;
47  typedef typename _MatrixType::Index Index;
48  };
49 
50  template<typename _MatrixType, int Options>
51  struct pastix_traits< PastixLDLT<_MatrixType,Options> >
52  {
53  typedef _MatrixType MatrixType;
54  typedef typename _MatrixType::Scalar Scalar;
55  typedef typename _MatrixType::RealScalar RealScalar;
56  typedef typename _MatrixType::Index Index;
57  };
58 
59  void eigen_pastix(pastix_data_t **pastix_data, int pastix_comm, int n, int *ptr, int *idx, float *vals, int *perm, int * invp, float *x, int nbrhs, int *iparm, double *dparm)
60  {
61  if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; }
62  if (nbrhs == 0) {x = NULL; nbrhs=1;}
63  s_pastix(pastix_data, pastix_comm, n, ptr, idx, vals, perm, invp, x, nbrhs, iparm, dparm);
64  }
65 
66  void eigen_pastix(pastix_data_t **pastix_data, int pastix_comm, int n, int *ptr, int *idx, double *vals, int *perm, int * invp, double *x, int nbrhs, int *iparm, double *dparm)
67  {
68  if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; }
69  if (nbrhs == 0) {x = NULL; nbrhs=1;}
70  d_pastix(pastix_data, pastix_comm, n, ptr, idx, vals, perm, invp, x, nbrhs, iparm, dparm);
71  }
72 
73  void eigen_pastix(pastix_data_t **pastix_data, int pastix_comm, int n, int *ptr, int *idx, std::complex<float> *vals, int *perm, int * invp, std::complex<float> *x, int nbrhs, int *iparm, double *dparm)
74  {
75  if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; }
76  if (nbrhs == 0) {x = NULL; nbrhs=1;}
77  c_pastix(pastix_data, pastix_comm, n, ptr, idx, reinterpret_cast<COMPLEX*>(vals), perm, invp, reinterpret_cast<COMPLEX*>(x), nbrhs, iparm, dparm);
78  }
79 
80  void eigen_pastix(pastix_data_t **pastix_data, int pastix_comm, int n, int *ptr, int *idx, std::complex<double> *vals, int *perm, int * invp, std::complex<double> *x, int nbrhs, int *iparm, double *dparm)
81  {
82  if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; }
83  if (nbrhs == 0) {x = NULL; nbrhs=1;}
84  z_pastix(pastix_data, pastix_comm, n, ptr, idx, reinterpret_cast<DCOMPLEX*>(vals), perm, invp, reinterpret_cast<DCOMPLEX*>(x), nbrhs, iparm, dparm);
85  }
86 
87  // Convert the matrix to Fortran-style Numbering
88  template <typename MatrixType>
89  void c_to_fortran_numbering (MatrixType& mat)
90  {
91  if ( !(mat.outerIndexPtr()[0]) )
92  {
93  int i;
94  for(i = 0; i <= mat.rows(); ++i)
95  ++mat.outerIndexPtr()[i];
96  for(i = 0; i < mat.nonZeros(); ++i)
97  ++mat.innerIndexPtr()[i];
98  }
99  }
100 
101  // Convert to C-style Numbering
102  template <typename MatrixType>
103  void fortran_to_c_numbering (MatrixType& mat)
104  {
105  // Check the Numbering
106  if ( mat.outerIndexPtr()[0] == 1 )
107  { // Convert to C-style numbering
108  int i;
109  for(i = 0; i <= mat.rows(); ++i)
110  --mat.outerIndexPtr()[i];
111  for(i = 0; i < mat.nonZeros(); ++i)
112  --mat.innerIndexPtr()[i];
113  }
114  }
115 }
116 
117 // This is the base class to interface with PaStiX functions.
118 // Users should not used this class directly.
119 template <class Derived>
120 class PastixBase : internal::noncopyable
121 {
122  public:
123  typedef typename internal::pastix_traits<Derived>::MatrixType _MatrixType;
124  typedef _MatrixType MatrixType;
125  typedef typename MatrixType::Scalar Scalar;
126  typedef typename MatrixType::RealScalar RealScalar;
127  typedef typename MatrixType::Index Index;
128  typedef Matrix<Scalar,Dynamic,1> Vector;
129  typedef SparseMatrix<Scalar, ColMajor> ColSpMatrix;
130 
131  public:
132 
133  PastixBase() : m_initisOk(false), m_analysisIsOk(false), m_factorizationIsOk(false), m_isInitialized(false), m_pastixdata(0), m_size(0)
134  {
135  init();
136  }
137 
138  ~PastixBase()
139  {
140  clean();
141  }
142 
147  template<typename Rhs>
148  inline const internal::solve_retval<PastixBase, Rhs>
149  solve(const MatrixBase<Rhs>& b) const
150  {
151  eigen_assert(m_isInitialized && "Pastix solver is not initialized.");
152  eigen_assert(rows()==b.rows()
153  && "PastixBase::solve(): invalid number of rows of the right hand side matrix b");
154  return internal::solve_retval<PastixBase, Rhs>(*this, b.derived());
155  }
156 
157  template<typename Rhs,typename Dest>
158  bool _solve (const MatrixBase<Rhs> &b, MatrixBase<Dest> &x) const;
159 
161  template<typename Rhs, typename DestScalar, int DestOptions, typename DestIndex>
162  void _solve_sparse(const Rhs& b, SparseMatrix<DestScalar,DestOptions,DestIndex> &dest) const
163  {
164  eigen_assert(m_factorizationIsOk && "The decomposition is not in a valid state for solving, you must first call either compute() or symbolic()/numeric()");
165  eigen_assert(rows()==b.rows());
166 
167  // we process the sparse rhs per block of NbColsAtOnce columns temporarily stored into a dense matrix.
168  static const int NbColsAtOnce = 1;
169  int rhsCols = b.cols();
170  int size = b.rows();
172  for(int k=0; k<rhsCols; k+=NbColsAtOnce)
173  {
174  int actualCols = std::min<int>(rhsCols-k, NbColsAtOnce);
175  tmp.leftCols(actualCols) = b.middleCols(k,actualCols);
176  tmp.leftCols(actualCols) = derived().solve(tmp.leftCols(actualCols));
177  dest.middleCols(k,actualCols) = tmp.leftCols(actualCols).sparseView();
178  }
179  }
180 
181  Derived& derived()
182  {
183  return *static_cast<Derived*>(this);
184  }
185  const Derived& derived() const
186  {
187  return *static_cast<const Derived*>(this);
188  }
189 
195  Array<Index,IPARM_SIZE,1>& iparm()
196  {
197  return m_iparm;
198  }
199 
204  int& iparm(int idxparam)
205  {
206  return m_iparm(idxparam);
207  }
208 
213  Array<RealScalar,IPARM_SIZE,1>& dparm()
214  {
215  return m_dparm;
216  }
217 
218 
222  double& dparm(int idxparam)
223  {
224  return m_dparm(idxparam);
225  }
226 
227  inline Index cols() const { return m_size; }
228  inline Index rows() const { return m_size; }
229 
238  ComputationInfo info() const
239  {
240  eigen_assert(m_isInitialized && "Decomposition is not initialized.");
241  return m_info;
242  }
243 
248  template<typename Rhs>
249  inline const internal::sparse_solve_retval<PastixBase, Rhs>
250  solve(const SparseMatrixBase<Rhs>& b) const
251  {
252  eigen_assert(m_isInitialized && "Pastix LU, LLT or LDLT is not initialized.");
253  eigen_assert(rows()==b.rows()
254  && "PastixBase::solve(): invalid number of rows of the right hand side matrix b");
255  return internal::sparse_solve_retval<PastixBase, Rhs>(*this, b.derived());
256  }
257 
258  protected:
259 
260  // Initialize the Pastix data structure, check the matrix
261  void init();
262 
263  // Compute the ordering and the symbolic factorization
264  void analyzePattern(ColSpMatrix& mat);
265 
266  // Compute the numerical factorization
267  void factorize(ColSpMatrix& mat);
268 
269  // Free all the data allocated by Pastix
270  void clean()
271  {
272  eigen_assert(m_initisOk && "The Pastix structure should be allocated first");
273  m_iparm(IPARM_START_TASK) = API_TASK_CLEAN;
274  m_iparm(IPARM_END_TASK) = API_TASK_CLEAN;
275  internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, 0, 0, 0, (Scalar*)0,
276  m_perm.data(), m_invp.data(), 0, 0, m_iparm.data(), m_dparm.data());
277  }
278 
279  void compute(ColSpMatrix& mat);
280 
281  int m_initisOk;
282  int m_analysisIsOk;
283  int m_factorizationIsOk;
284  bool m_isInitialized;
285  mutable ComputationInfo m_info;
286  mutable pastix_data_t *m_pastixdata; // Data structure for pastix
287  mutable int m_comm; // The MPI communicator identifier
288  mutable Matrix<int,IPARM_SIZE,1> m_iparm; // integer vector for the input parameters
289  mutable Matrix<double,DPARM_SIZE,1> m_dparm; // Scalar vector for the input parameters
290  mutable Matrix<Index,Dynamic,1> m_perm; // Permutation vector
291  mutable Matrix<Index,Dynamic,1> m_invp; // Inverse permutation vector
292  mutable int m_size; // Size of the matrix
293 };
294 
299 template <class Derived>
300 void PastixBase<Derived>::init()
301 {
302  m_size = 0;
303  m_iparm.setZero(IPARM_SIZE);
304  m_dparm.setZero(DPARM_SIZE);
305 
306  m_iparm(IPARM_MODIFY_PARAMETER) = API_NO;
307  pastix(&m_pastixdata, MPI_COMM_WORLD,
308  0, 0, 0, 0,
309  0, 0, 0, 1, m_iparm.data(), m_dparm.data());
310 
311  m_iparm[IPARM_MATRIX_VERIFICATION] = API_NO;
312  m_iparm[IPARM_VERBOSE] = 2;
313  m_iparm[IPARM_ORDERING] = API_ORDER_SCOTCH;
314  m_iparm[IPARM_INCOMPLETE] = API_NO;
315  m_iparm[IPARM_OOC_LIMIT] = 2000;
316  m_iparm[IPARM_RHS_MAKING] = API_RHS_B;
317  m_iparm(IPARM_MATRIX_VERIFICATION) = API_NO;
318 
319  m_iparm(IPARM_START_TASK) = API_TASK_INIT;
320  m_iparm(IPARM_END_TASK) = API_TASK_INIT;
321  internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, 0, 0, 0, (Scalar*)0,
322  0, 0, 0, 0, m_iparm.data(), m_dparm.data());
323 
324  // Check the returned error
325  if(m_iparm(IPARM_ERROR_NUMBER)) {
326  m_info = InvalidInput;
327  m_initisOk = false;
328  }
329  else {
330  m_info = Success;
331  m_initisOk = true;
332  }
333 }
334 
335 template <class Derived>
336 void PastixBase<Derived>::compute(ColSpMatrix& mat)
337 {
338  eigen_assert(mat.rows() == mat.cols() && "The input matrix should be squared");
339 
340  analyzePattern(mat);
341  factorize(mat);
342 
343  m_iparm(IPARM_MATRIX_VERIFICATION) = API_NO;
344  m_isInitialized = m_factorizationIsOk;
345 }
346 
347 
348 template <class Derived>
349 void PastixBase<Derived>::analyzePattern(ColSpMatrix& mat)
350 {
351  eigen_assert(m_initisOk && "The initialization of PaSTiX failed");
352 
353  // clean previous calls
354  if(m_size>0)
355  clean();
356 
357  m_size = mat.rows();
358  m_perm.resize(m_size);
359  m_invp.resize(m_size);
360 
361  m_iparm(IPARM_START_TASK) = API_TASK_ORDERING;
362  m_iparm(IPARM_END_TASK) = API_TASK_ANALYSE;
363  internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, m_size, mat.outerIndexPtr(), mat.innerIndexPtr(),
364  mat.valuePtr(), m_perm.data(), m_invp.data(), 0, 0, m_iparm.data(), m_dparm.data());
365 
366  // Check the returned error
367  if(m_iparm(IPARM_ERROR_NUMBER))
368  {
369  m_info = NumericalIssue;
370  m_analysisIsOk = false;
371  }
372  else
373  {
374  m_info = Success;
375  m_analysisIsOk = true;
376  }
377 }
378 
379 template <class Derived>
380 void PastixBase<Derived>::factorize(ColSpMatrix& mat)
381 {
382 // if(&m_cpyMat != &mat) m_cpyMat = mat;
383  eigen_assert(m_analysisIsOk && "The analysis phase should be called before the factorization phase");
384  m_iparm(IPARM_START_TASK) = API_TASK_NUMFACT;
385  m_iparm(IPARM_END_TASK) = API_TASK_NUMFACT;
386  m_size = mat.rows();
387 
388  internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, m_size, mat.outerIndexPtr(), mat.innerIndexPtr(),
389  mat.valuePtr(), m_perm.data(), m_invp.data(), 0, 0, m_iparm.data(), m_dparm.data());
390 
391  // Check the returned error
392  if(m_iparm(IPARM_ERROR_NUMBER))
393  {
394  m_info = NumericalIssue;
395  m_factorizationIsOk = false;
396  m_isInitialized = false;
397  }
398  else
399  {
400  m_info = Success;
401  m_factorizationIsOk = true;
402  m_isInitialized = true;
403  }
404 }
405 
406 /* Solve the system */
407 template<typename Base>
408 template<typename Rhs,typename Dest>
409 bool PastixBase<Base>::_solve (const MatrixBase<Rhs> &b, MatrixBase<Dest> &x) const
410 {
411  eigen_assert(m_isInitialized && "The matrix should be factorized first");
412  EIGEN_STATIC_ASSERT((Dest::Flags&RowMajorBit)==0,
413  THIS_METHOD_IS_ONLY_FOR_COLUMN_MAJOR_MATRICES);
414  int rhs = 1;
415 
416  x = b; /* on return, x is overwritten by the computed solution */
417 
418  for (int i = 0; i < b.cols(); i++){
419  m_iparm[IPARM_START_TASK] = API_TASK_SOLVE;
420  m_iparm[IPARM_END_TASK] = API_TASK_REFINE;
421 
422  internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, x.rows(), 0, 0, 0,
423  m_perm.data(), m_invp.data(), &x(0, i), rhs, m_iparm.data(), m_dparm.data());
424  }
425 
426  // Check the returned error
427  m_info = m_iparm(IPARM_ERROR_NUMBER)==0 ? Success : NumericalIssue;
428 
429  return m_iparm(IPARM_ERROR_NUMBER)==0;
430 }
431 
451 template<typename _MatrixType, bool IsStrSym>
452 class PastixLU : public PastixBase< PastixLU<_MatrixType> >
453 {
454  public:
455  typedef _MatrixType MatrixType;
456  typedef PastixBase<PastixLU<MatrixType> > Base;
457  typedef typename Base::ColSpMatrix ColSpMatrix;
458  typedef typename MatrixType::Index Index;
459 
460  public:
461  PastixLU() : Base()
462  {
463  init();
464  }
465 
466  PastixLU(const MatrixType& matrix):Base()
467  {
468  init();
469  compute(matrix);
470  }
476  void compute (const MatrixType& matrix)
477  {
478  m_structureIsUptodate = false;
479  ColSpMatrix temp;
480  grabMatrix(matrix, temp);
481  Base::compute(temp);
482  }
488  void analyzePattern(const MatrixType& matrix)
489  {
490  m_structureIsUptodate = false;
491  ColSpMatrix temp;
492  grabMatrix(matrix, temp);
493  Base::analyzePattern(temp);
494  }
495 
501  void factorize(const MatrixType& matrix)
502  {
503  ColSpMatrix temp;
504  grabMatrix(matrix, temp);
505  Base::factorize(temp);
506  }
507  protected:
508 
509  void init()
510  {
511  m_structureIsUptodate = false;
512  m_iparm(IPARM_SYM) = API_SYM_NO;
513  m_iparm(IPARM_FACTORIZATION) = API_FACT_LU;
514  }
515 
516  void grabMatrix(const MatrixType& matrix, ColSpMatrix& out)
517  {
518  if(IsStrSym)
519  out = matrix;
520  else
521  {
522  if(!m_structureIsUptodate)
523  {
524  // update the transposed structure
525  m_transposedStructure = matrix.transpose();
526 
527  // Set the elements of the matrix to zero
528  for (Index j=0; j<m_transposedStructure.outerSize(); ++j)
529  for(typename ColSpMatrix::InnerIterator it(m_transposedStructure, j); it; ++it)
530  it.valueRef() = 0.0;
531 
532  m_structureIsUptodate = true;
533  }
534 
535  out = m_transposedStructure + matrix;
536  }
537  internal::c_to_fortran_numbering(out);
538  }
539 
540  using Base::m_iparm;
541  using Base::m_dparm;
542 
543  ColSpMatrix m_transposedStructure;
544  bool m_structureIsUptodate;
545 };
546 
561 template<typename _MatrixType, int _UpLo>
562 class PastixLLT : public PastixBase< PastixLLT<_MatrixType, _UpLo> >
563 {
564  public:
565  typedef _MatrixType MatrixType;
566  typedef PastixBase<PastixLLT<MatrixType, _UpLo> > Base;
567  typedef typename Base::ColSpMatrix ColSpMatrix;
568 
569  public:
570  enum { UpLo = _UpLo };
571  PastixLLT() : Base()
572  {
573  init();
574  }
575 
576  PastixLLT(const MatrixType& matrix):Base()
577  {
578  init();
579  compute(matrix);
580  }
581 
585  void compute (const MatrixType& matrix)
586  {
587  ColSpMatrix temp;
588  grabMatrix(matrix, temp);
589  Base::compute(temp);
590  }
591 
596  void analyzePattern(const MatrixType& matrix)
597  {
598  ColSpMatrix temp;
599  grabMatrix(matrix, temp);
600  Base::analyzePattern(temp);
601  }
605  void factorize(const MatrixType& matrix)
606  {
607  ColSpMatrix temp;
608  grabMatrix(matrix, temp);
609  Base::factorize(temp);
610  }
611  protected:
612  using Base::m_iparm;
613 
614  void init()
615  {
616  m_iparm(IPARM_SYM) = API_SYM_YES;
617  m_iparm(IPARM_FACTORIZATION) = API_FACT_LLT;
618  }
619 
620  void grabMatrix(const MatrixType& matrix, ColSpMatrix& out)
621  {
622  // Pastix supports only lower, column-major matrices
623  out.template selfadjointView<Lower>() = matrix.template selfadjointView<UpLo>();
624  internal::c_to_fortran_numbering(out);
625  }
626 };
627 
642 template<typename _MatrixType, int _UpLo>
643 class PastixLDLT : public PastixBase< PastixLDLT<_MatrixType, _UpLo> >
644 {
645  public:
646  typedef _MatrixType MatrixType;
647  typedef PastixBase<PastixLDLT<MatrixType, _UpLo> > Base;
648  typedef typename Base::ColSpMatrix ColSpMatrix;
649 
650  public:
651  enum { UpLo = _UpLo };
652  PastixLDLT():Base()
653  {
654  init();
655  }
656 
657  PastixLDLT(const MatrixType& matrix):Base()
658  {
659  init();
660  compute(matrix);
661  }
662 
666  void compute (const MatrixType& matrix)
667  {
668  ColSpMatrix temp;
669  grabMatrix(matrix, temp);
670  Base::compute(temp);
671  }
672 
677  void analyzePattern(const MatrixType& matrix)
678  {
679  ColSpMatrix temp;
680  grabMatrix(matrix, temp);
681  Base::analyzePattern(temp);
682  }
686  void factorize(const MatrixType& matrix)
687  {
688  ColSpMatrix temp;
689  grabMatrix(matrix, temp);
690  Base::factorize(temp);
691  }
692 
693  protected:
694  using Base::m_iparm;
695 
696  void init()
697  {
698  m_iparm(IPARM_SYM) = API_SYM_YES;
699  m_iparm(IPARM_FACTORIZATION) = API_FACT_LDLT;
700  }
701 
702  void grabMatrix(const MatrixType& matrix, ColSpMatrix& out)
703  {
704  // Pastix supports only lower, column-major matrices
705  out.template selfadjointView<Lower>() = matrix.template selfadjointView<UpLo>();
706  internal::c_to_fortran_numbering(out);
707  }
708 };
709 
710 namespace internal {
711 
712 template<typename _MatrixType, typename Rhs>
713 struct solve_retval<PastixBase<_MatrixType>, Rhs>
714  : solve_retval_base<PastixBase<_MatrixType>, Rhs>
715 {
716  typedef PastixBase<_MatrixType> Dec;
717  EIGEN_MAKE_SOLVE_HELPERS(Dec,Rhs)
718 
719  template<typename Dest> void evalTo(Dest& dst) const
720  {
721  dec()._solve(rhs(),dst);
722  }
723 };
724 
725 template<typename _MatrixType, typename Rhs>
726 struct sparse_solve_retval<PastixBase<_MatrixType>, Rhs>
727  : sparse_solve_retval_base<PastixBase<_MatrixType>, Rhs>
728 {
729  typedef PastixBase<_MatrixType> Dec;
730  EIGEN_MAKE_SPARSE_SOLVE_HELPERS(Dec,Rhs)
731 
732  template<typename Dest> void evalTo(Dest& dst) const
733  {
734  dec()._solve_sparse(rhs(),dst);
735  }
736 };
737 
738 } // end namespace internal
739 
740 } // end namespace Eigen
741 
742 #endif