123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198 |
- // This file is part of Eigen, a lightweight C++ template library
- // for linear algebra.
- //
- // Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
- //
- // This Source Code Form is subject to the terms of the Mozilla
- // Public License v. 2.0. If a copy of the MPL was not distributed
- // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
- #ifndef EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
- #define EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
- namespace Eigen {
- namespace internal {
- // perform a pseudo in-place sparse * sparse product assuming all matrices are col major
- template<typename Lhs, typename Rhs, typename ResultType>
- static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res, const typename ResultType::RealScalar& tolerance)
- {
- // return sparse_sparse_product_with_pruning_impl2(lhs,rhs,res);
- typedef typename remove_all<Rhs>::type::Scalar RhsScalar;
- typedef typename remove_all<ResultType>::type::Scalar ResScalar;
- typedef typename remove_all<Lhs>::type::StorageIndex StorageIndex;
- // make sure to call innerSize/outerSize since we fake the storage order.
- Index rows = lhs.innerSize();
- Index cols = rhs.outerSize();
- //Index size = lhs.outerSize();
- eigen_assert(lhs.outerSize() == rhs.innerSize());
- // allocate a temporary buffer
- AmbiVector<ResScalar,StorageIndex> tempVector(rows);
- // mimics a resizeByInnerOuter:
- if(ResultType::IsRowMajor)
- res.resize(cols, rows);
- else
- res.resize(rows, cols);
-
- evaluator<Lhs> lhsEval(lhs);
- evaluator<Rhs> rhsEval(rhs);
-
- // estimate the number of non zero entries
- // given a rhs column containing Y non zeros, we assume that the respective Y columns
- // of the lhs differs in average of one non zeros, thus the number of non zeros for
- // the product of a rhs column with the lhs is X+Y where X is the average number of non zero
- // per column of the lhs.
- // Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs)
- Index estimated_nnz_prod = lhsEval.nonZerosEstimate() + rhsEval.nonZerosEstimate();
- res.reserve(estimated_nnz_prod);
- double ratioColRes = double(estimated_nnz_prod)/(double(lhs.rows())*double(rhs.cols()));
- for (Index j=0; j<cols; ++j)
- {
- // FIXME:
- //double ratioColRes = (double(rhs.innerVector(j).nonZeros()) + double(lhs.nonZeros())/double(lhs.cols()))/double(lhs.rows());
- // let's do a more accurate determination of the nnz ratio for the current column j of res
- tempVector.init(ratioColRes);
- tempVector.setZero();
- for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt)
- {
- // FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index())
- tempVector.restart();
- RhsScalar x = rhsIt.value();
- for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, rhsIt.index()); lhsIt; ++lhsIt)
- {
- tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x;
- }
- }
- res.startVec(j);
- for (typename AmbiVector<ResScalar,StorageIndex>::Iterator it(tempVector,tolerance); it; ++it)
- res.insertBackByOuterInner(j,it.index()) = it.value();
- }
- res.finalize();
- }
- template<typename Lhs, typename Rhs, typename ResultType,
- int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit,
- int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit,
- int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit>
- struct sparse_sparse_product_with_pruning_selector;
- template<typename Lhs, typename Rhs, typename ResultType>
- struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor>
- {
- typedef typename ResultType::RealScalar RealScalar;
- static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
- {
- typename remove_all<ResultType>::type _res(res.rows(), res.cols());
- internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res, tolerance);
- res.swap(_res);
- }
- };
- template<typename Lhs, typename Rhs, typename ResultType>
- struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor>
- {
- typedef typename ResultType::RealScalar RealScalar;
- static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
- {
- // we need a col-major matrix to hold the result
- typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> SparseTemporaryType;
- SparseTemporaryType _res(res.rows(), res.cols());
- internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res, tolerance);
- res = _res;
- }
- };
- template<typename Lhs, typename Rhs, typename ResultType>
- struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor>
- {
- typedef typename ResultType::RealScalar RealScalar;
- static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
- {
- // let's transpose the product to get a column x column product
- typename remove_all<ResultType>::type _res(res.rows(), res.cols());
- internal::sparse_sparse_product_with_pruning_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res, tolerance);
- res.swap(_res);
- }
- };
- template<typename Lhs, typename Rhs, typename ResultType>
- struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
- {
- typedef typename ResultType::RealScalar RealScalar;
- static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
- {
- typedef SparseMatrix<typename Lhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixLhs;
- typedef SparseMatrix<typename Rhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixRhs;
- ColMajorMatrixLhs colLhs(lhs);
- ColMajorMatrixRhs colRhs(rhs);
- internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,ColMajorMatrixRhs,ResultType>(colLhs, colRhs, res, tolerance);
- // let's transpose the product to get a column x column product
- // typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
- // SparseTemporaryType _res(res.cols(), res.rows());
- // sparse_sparse_product_with_pruning_impl<Rhs,Lhs,SparseTemporaryType>(rhs, lhs, _res);
- // res = _res.transpose();
- }
- };
- template<typename Lhs, typename Rhs, typename ResultType>
- struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,RowMajor>
- {
- typedef typename ResultType::RealScalar RealScalar;
- static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
- {
- typedef SparseMatrix<typename Lhs::Scalar,RowMajor,typename Lhs::StorageIndex> RowMajorMatrixLhs;
- RowMajorMatrixLhs rowLhs(lhs);
- sparse_sparse_product_with_pruning_selector<RowMajorMatrixLhs,Rhs,ResultType,RowMajor,RowMajor>(rowLhs,rhs,res,tolerance);
- }
- };
- template<typename Lhs, typename Rhs, typename ResultType>
- struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,RowMajor>
- {
- typedef typename ResultType::RealScalar RealScalar;
- static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
- {
- typedef SparseMatrix<typename Rhs::Scalar,RowMajor,typename Lhs::StorageIndex> RowMajorMatrixRhs;
- RowMajorMatrixRhs rowRhs(rhs);
- sparse_sparse_product_with_pruning_selector<Lhs,RowMajorMatrixRhs,ResultType,RowMajor,RowMajor,RowMajor>(lhs,rowRhs,res,tolerance);
- }
- };
- template<typename Lhs, typename Rhs, typename ResultType>
- struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,ColMajor>
- {
- typedef typename ResultType::RealScalar RealScalar;
- static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
- {
- typedef SparseMatrix<typename Rhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixRhs;
- ColMajorMatrixRhs colRhs(rhs);
- internal::sparse_sparse_product_with_pruning_impl<Lhs,ColMajorMatrixRhs,ResultType>(lhs, colRhs, res, tolerance);
- }
- };
- template<typename Lhs, typename Rhs, typename ResultType>
- struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,ColMajor>
- {
- typedef typename ResultType::RealScalar RealScalar;
- static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
- {
- typedef SparseMatrix<typename Lhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixLhs;
- ColMajorMatrixLhs colLhs(lhs);
- internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,Rhs,ResultType>(colLhs, rhs, res, tolerance);
- }
- };
- } // end namespace internal
- } // end namespace Eigen
- #endif // EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
|