SparseSparseProductWithPruning.h 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. // This file is part of Eigen, a lightweight C++ template library
  2. // for linear algebra.
  3. //
  4. // Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
  5. //
  6. // This Source Code Form is subject to the terms of the Mozilla
  7. // Public License v. 2.0. If a copy of the MPL was not distributed
  8. // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
  9. #ifndef EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
  10. #define EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
  11. namespace Eigen {
  12. namespace internal {
  13. // perform a pseudo in-place sparse * sparse product assuming all matrices are col major
  14. template<typename Lhs, typename Rhs, typename ResultType>
  15. static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res, const typename ResultType::RealScalar& tolerance)
  16. {
  17. // return sparse_sparse_product_with_pruning_impl2(lhs,rhs,res);
  18. typedef typename remove_all<Rhs>::type::Scalar RhsScalar;
  19. typedef typename remove_all<ResultType>::type::Scalar ResScalar;
  20. typedef typename remove_all<Lhs>::type::StorageIndex StorageIndex;
  21. // make sure to call innerSize/outerSize since we fake the storage order.
  22. Index rows = lhs.innerSize();
  23. Index cols = rhs.outerSize();
  24. //Index size = lhs.outerSize();
  25. eigen_assert(lhs.outerSize() == rhs.innerSize());
  26. // allocate a temporary buffer
  27. AmbiVector<ResScalar,StorageIndex> tempVector(rows);
  28. // mimics a resizeByInnerOuter:
  29. if(ResultType::IsRowMajor)
  30. res.resize(cols, rows);
  31. else
  32. res.resize(rows, cols);
  33. evaluator<Lhs> lhsEval(lhs);
  34. evaluator<Rhs> rhsEval(rhs);
  35. // estimate the number of non zero entries
  36. // given a rhs column containing Y non zeros, we assume that the respective Y columns
  37. // of the lhs differs in average of one non zeros, thus the number of non zeros for
  38. // the product of a rhs column with the lhs is X+Y where X is the average number of non zero
  39. // per column of the lhs.
  40. // Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs)
  41. Index estimated_nnz_prod = lhsEval.nonZerosEstimate() + rhsEval.nonZerosEstimate();
  42. res.reserve(estimated_nnz_prod);
  43. double ratioColRes = double(estimated_nnz_prod)/(double(lhs.rows())*double(rhs.cols()));
  44. for (Index j=0; j<cols; ++j)
  45. {
  46. // FIXME:
  47. //double ratioColRes = (double(rhs.innerVector(j).nonZeros()) + double(lhs.nonZeros())/double(lhs.cols()))/double(lhs.rows());
  48. // let's do a more accurate determination of the nnz ratio for the current column j of res
  49. tempVector.init(ratioColRes);
  50. tempVector.setZero();
  51. for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt)
  52. {
  53. // FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index())
  54. tempVector.restart();
  55. RhsScalar x = rhsIt.value();
  56. for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, rhsIt.index()); lhsIt; ++lhsIt)
  57. {
  58. tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x;
  59. }
  60. }
  61. res.startVec(j);
  62. for (typename AmbiVector<ResScalar,StorageIndex>::Iterator it(tempVector,tolerance); it; ++it)
  63. res.insertBackByOuterInner(j,it.index()) = it.value();
  64. }
  65. res.finalize();
  66. }
  67. template<typename Lhs, typename Rhs, typename ResultType,
  68. int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit,
  69. int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit,
  70. int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit>
  71. struct sparse_sparse_product_with_pruning_selector;
  72. template<typename Lhs, typename Rhs, typename ResultType>
  73. struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor>
  74. {
  75. typedef typename ResultType::RealScalar RealScalar;
  76. static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
  77. {
  78. typename remove_all<ResultType>::type _res(res.rows(), res.cols());
  79. internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res, tolerance);
  80. res.swap(_res);
  81. }
  82. };
  83. template<typename Lhs, typename Rhs, typename ResultType>
  84. struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor>
  85. {
  86. typedef typename ResultType::RealScalar RealScalar;
  87. static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
  88. {
  89. // we need a col-major matrix to hold the result
  90. typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> SparseTemporaryType;
  91. SparseTemporaryType _res(res.rows(), res.cols());
  92. internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res, tolerance);
  93. res = _res;
  94. }
  95. };
  96. template<typename Lhs, typename Rhs, typename ResultType>
  97. struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor>
  98. {
  99. typedef typename ResultType::RealScalar RealScalar;
  100. static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
  101. {
  102. // let's transpose the product to get a column x column product
  103. typename remove_all<ResultType>::type _res(res.rows(), res.cols());
  104. internal::sparse_sparse_product_with_pruning_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res, tolerance);
  105. res.swap(_res);
  106. }
  107. };
  108. template<typename Lhs, typename Rhs, typename ResultType>
  109. struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
  110. {
  111. typedef typename ResultType::RealScalar RealScalar;
  112. static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
  113. {
  114. typedef SparseMatrix<typename Lhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixLhs;
  115. typedef SparseMatrix<typename Rhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixRhs;
  116. ColMajorMatrixLhs colLhs(lhs);
  117. ColMajorMatrixRhs colRhs(rhs);
  118. internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,ColMajorMatrixRhs,ResultType>(colLhs, colRhs, res, tolerance);
  119. // let's transpose the product to get a column x column product
  120. // typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
  121. // SparseTemporaryType _res(res.cols(), res.rows());
  122. // sparse_sparse_product_with_pruning_impl<Rhs,Lhs,SparseTemporaryType>(rhs, lhs, _res);
  123. // res = _res.transpose();
  124. }
  125. };
  126. template<typename Lhs, typename Rhs, typename ResultType>
  127. struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,RowMajor>
  128. {
  129. typedef typename ResultType::RealScalar RealScalar;
  130. static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
  131. {
  132. typedef SparseMatrix<typename Lhs::Scalar,RowMajor,typename Lhs::StorageIndex> RowMajorMatrixLhs;
  133. RowMajorMatrixLhs rowLhs(lhs);
  134. sparse_sparse_product_with_pruning_selector<RowMajorMatrixLhs,Rhs,ResultType,RowMajor,RowMajor>(rowLhs,rhs,res,tolerance);
  135. }
  136. };
  137. template<typename Lhs, typename Rhs, typename ResultType>
  138. struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,RowMajor>
  139. {
  140. typedef typename ResultType::RealScalar RealScalar;
  141. static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
  142. {
  143. typedef SparseMatrix<typename Rhs::Scalar,RowMajor,typename Lhs::StorageIndex> RowMajorMatrixRhs;
  144. RowMajorMatrixRhs rowRhs(rhs);
  145. sparse_sparse_product_with_pruning_selector<Lhs,RowMajorMatrixRhs,ResultType,RowMajor,RowMajor,RowMajor>(lhs,rowRhs,res,tolerance);
  146. }
  147. };
  148. template<typename Lhs, typename Rhs, typename ResultType>
  149. struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,ColMajor>
  150. {
  151. typedef typename ResultType::RealScalar RealScalar;
  152. static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
  153. {
  154. typedef SparseMatrix<typename Rhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixRhs;
  155. ColMajorMatrixRhs colRhs(rhs);
  156. internal::sparse_sparse_product_with_pruning_impl<Lhs,ColMajorMatrixRhs,ResultType>(lhs, colRhs, res, tolerance);
  157. }
  158. };
  159. template<typename Lhs, typename Rhs, typename ResultType>
  160. struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,ColMajor>
  161. {
  162. typedef typename ResultType::RealScalar RealScalar;
  163. static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
  164. {
  165. typedef SparseMatrix<typename Lhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixLhs;
  166. ColMajorMatrixLhs colLhs(lhs);
  167. internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,Rhs,ResultType>(colLhs, rhs, res, tolerance);
  168. }
  169. };
  170. } // end namespace internal
  171. } // end namespace Eigen
  172. #endif // EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H