345 lines
8.6 KiB
C++
345 lines
8.6 KiB
C++
// Copyright (C) 2015 National ICT Australia (NICTA)
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
// -------------------------------------------------------------------
//
// Written by Conrad Sanderson - http://conradsanderson.id.au
|
|
|
|
//! \addtogroup fn_svds
|
|
//! @{
|
|
|
|
|
|
template<typename T1>
|
|
inline
|
|
bool
|
|
svds_helper
|
|
(
|
|
Mat<typename T1::elem_type>& U,
|
|
Col<typename T1::pod_type >& S,
|
|
Mat<typename T1::elem_type>& V,
|
|
const SpBase<typename T1::elem_type,T1>& X,
|
|
const uword k,
|
|
const typename T1::pod_type tol,
|
|
const bool calc_UV,
|
|
const typename arma_real_only<typename T1::elem_type>::result* junk = 0
|
|
)
|
|
{
|
|
arma_extra_debug_sigprint();
|
|
arma_ignore(junk);
|
|
|
|
typedef typename T1::elem_type eT;
|
|
typedef typename T1::pod_type T;
|
|
|
|
if(arma_config::arpack == false)
|
|
{
|
|
arma_stop("svds(): use of ARPACK must be enabled");
|
|
return false;
|
|
}
|
|
|
|
arma_debug_check
|
|
(
|
|
( ((void*)(&U) == (void*)(&S)) || (&U == &V) || ((void*)(&S) == (void*)(&V)) ),
|
|
"svds(): two or more output objects are the same object"
|
|
);
|
|
|
|
arma_debug_check( (tol < T(0)), "svds(): tol must be >= 0" );
|
|
|
|
const unwrap_spmat<T1> tmp(X.get_ref());
|
|
const SpMat<eT>& A = tmp.M;
|
|
|
|
const uword kk = (std::min)( (std::min)(A.n_rows, A.n_cols), k );
|
|
|
|
const T A_max = (A.n_nonzero > 0) ? T(max(abs(Col<eT>(const_cast<eT*>(A.values), A.n_nonzero, false)))) : T(0);
|
|
|
|
if(A_max == T(0))
|
|
{
|
|
// TODO: use reset instead ?
|
|
S.zeros(kk);
|
|
|
|
if(calc_UV)
|
|
{
|
|
U.eye(A.n_rows, kk);
|
|
V.eye(A.n_cols, kk);
|
|
}
|
|
}
|
|
else
|
|
{
|
|
SpMat<eT> C( (A.n_rows + A.n_cols), (A.n_rows + A.n_cols) );
|
|
|
|
SpMat<eT> B = A / A_max;
|
|
SpMat<eT> Bt = B.t();
|
|
|
|
C(0, A.n_rows, size(B) ) = B;
|
|
C(A.n_rows, 0, size(Bt)) = Bt;
|
|
|
|
Bt.reset();
|
|
B.reset();
|
|
|
|
Col<eT> eigval;
|
|
Mat<eT> eigvec;
|
|
|
|
const bool status = sp_auxlib::eigs_sym(eigval, eigvec, C, kk, "la", (tol / Datum<T>::sqrt2));
|
|
|
|
if(status == false)
|
|
{
|
|
U.reset();
|
|
S.reset();
|
|
V.reset();
|
|
|
|
return false;
|
|
}
|
|
|
|
const T A_norm = max(eigval);
|
|
|
|
const T tol2 = tol / Datum<T>::sqrt2 * A_norm;
|
|
|
|
uvec indices = find(eigval > tol2);
|
|
|
|
if(indices.n_elem > kk)
|
|
{
|
|
indices = indices.subvec(0,kk-1);
|
|
}
|
|
else
|
|
if(indices.n_elem < kk)
|
|
{
|
|
const uvec indices2 = find(abs(eigval) <= tol2);
|
|
|
|
const uword N_extra = (std::min)( indices2.n_elem, (kk - indices.n_elem) );
|
|
|
|
if(N_extra > 0) { indices = join_cols(indices, indices2.subvec(0,N_extra-1)); }
|
|
}
|
|
|
|
const uvec sorted_indices = sort_index(eigval, "descend");
|
|
|
|
S = eigval.elem(sorted_indices); S *= A_max;
|
|
|
|
if(calc_UV)
|
|
{
|
|
uvec U_row_indices(A.n_rows); for(uword i=0; i < A.n_rows; ++i) { U_row_indices[i] = i; }
|
|
uvec V_row_indices(A.n_cols); for(uword i=0; i < A.n_cols; ++i) { V_row_indices[i] = i + A.n_rows; }
|
|
|
|
U = Datum<T>::sqrt2 * eigvec(U_row_indices, sorted_indices);
|
|
V = Datum<T>::sqrt2 * eigvec(V_row_indices, sorted_indices);
|
|
}
|
|
}
|
|
|
|
if(S.n_elem < k) { arma_debug_warn("svds(): found fewer singular values than specified"); }
|
|
|
|
return true;
|
|
}
|
|
|
|
|
|
|
|
template<typename T1>
|
|
inline
|
|
bool
|
|
svds_helper
|
|
(
|
|
Mat<typename T1::elem_type>& U,
|
|
Col<typename T1::pod_type >& S,
|
|
Mat<typename T1::elem_type>& V,
|
|
const SpBase<typename T1::elem_type,T1>& X,
|
|
const uword k,
|
|
const typename T1::pod_type tol,
|
|
const bool calc_UV,
|
|
const typename arma_cx_only<typename T1::elem_type>::result* junk = 0
|
|
)
|
|
{
|
|
arma_extra_debug_sigprint();
|
|
arma_ignore(junk);
|
|
|
|
typedef typename T1::elem_type eT;
|
|
typedef typename T1::pod_type T;
|
|
|
|
if(arma_config::arpack == false)
|
|
{
|
|
arma_stop("svds(): use of ARPACK must be enabled");
|
|
return false;
|
|
}
|
|
|
|
arma_debug_check
|
|
(
|
|
( ((void*)(&U) == (void*)(&S)) || (&U == &V) || ((void*)(&S) == (void*)(&V)) ),
|
|
"svds(): two or more output objects are the same object"
|
|
);
|
|
|
|
arma_debug_check( (tol < T(0)), "svds(): tol must be >= 0" );
|
|
|
|
const unwrap_spmat<T1> tmp(X.get_ref());
|
|
const SpMat<eT>& A = tmp.M;
|
|
|
|
const uword kk = (std::min)( (std::min)(A.n_rows, A.n_cols), k );
|
|
|
|
const T A_max = (A.n_nonzero > 0) ? T(max(abs(Col<eT>(const_cast<eT*>(A.values), A.n_nonzero, false)))) : T(0);
|
|
|
|
if(A_max == T(0))
|
|
{
|
|
// TODO: use reset instead ?
|
|
S.zeros(kk);
|
|
|
|
if(calc_UV)
|
|
{
|
|
U.eye(A.n_rows, kk);
|
|
V.eye(A.n_cols, kk);
|
|
}
|
|
}
|
|
else
|
|
{
|
|
SpMat<eT> C( (A.n_rows + A.n_cols), (A.n_rows + A.n_cols) );
|
|
|
|
SpMat<eT> B = A / A_max;
|
|
SpMat<eT> Bt = B.t();
|
|
|
|
C(0, A.n_rows, size(B) ) = B;
|
|
C(A.n_rows, 0, size(Bt)) = Bt;
|
|
|
|
Bt.reset();
|
|
B.reset();
|
|
|
|
Col<eT> eigval_tmp;
|
|
Mat<eT> eigvec;
|
|
|
|
const bool status = sp_auxlib::eigs_gen(eigval_tmp, eigvec, C, kk, "lr", (tol / Datum<T>::sqrt2));
|
|
|
|
if(status == false)
|
|
{
|
|
U.reset();
|
|
S.reset();
|
|
V.reset();
|
|
arma_debug_warn("svds(): decomposition failed");
|
|
|
|
return false;
|
|
}
|
|
|
|
const Col<T> eigval = real(eigval_tmp);
|
|
|
|
const T A_norm = max(eigval);
|
|
|
|
const T tol2 = tol / Datum<T>::sqrt2 * A_norm;
|
|
|
|
uvec indices = find(eigval > tol2);
|
|
|
|
if(indices.n_elem > kk)
|
|
{
|
|
indices = indices.subvec(0,kk-1);
|
|
}
|
|
else
|
|
if(indices.n_elem < kk)
|
|
{
|
|
const uvec indices2 = find(abs(eigval) <= tol2);
|
|
|
|
const uword N_extra = (std::min)( indices2.n_elem, (kk - indices.n_elem) );
|
|
|
|
if(N_extra > 0) { indices = join_cols(indices, indices2.subvec(0,N_extra-1)); }
|
|
}
|
|
|
|
const uvec sorted_indices = sort_index(eigval, "descend");
|
|
|
|
S = eigval.elem(sorted_indices); S *= A_max;
|
|
|
|
if(calc_UV)
|
|
{
|
|
uvec U_row_indices(A.n_rows); for(uword i=0; i < A.n_rows; ++i) { U_row_indices[i] = i; }
|
|
uvec V_row_indices(A.n_cols); for(uword i=0; i < A.n_cols; ++i) { V_row_indices[i] = i + A.n_rows; }
|
|
|
|
U = Datum<T>::sqrt2 * eigvec(U_row_indices, sorted_indices);
|
|
V = Datum<T>::sqrt2 * eigvec(V_row_indices, sorted_indices);
|
|
}
|
|
}
|
|
|
|
if(S.n_elem < k) { arma_debug_warn("svds(): found fewer singular values than specified"); }
|
|
|
|
return true;
|
|
}
|
|
|
|
|
|
|
|
//! find the k largest singular values and corresponding singular vectors of sparse matrix X
|
|
template<typename T1>
|
|
inline
|
|
bool
|
|
svds
|
|
(
|
|
Mat<typename T1::elem_type>& U,
|
|
Col<typename T1::pod_type >& S,
|
|
Mat<typename T1::elem_type>& V,
|
|
const SpBase<typename T1::elem_type,T1>& X,
|
|
const uword k,
|
|
const typename T1::pod_type tol = 0.0,
|
|
const typename arma_real_or_cx_only<typename T1::elem_type>::result* junk = 0
|
|
)
|
|
{
|
|
arma_extra_debug_sigprint();
|
|
arma_ignore(junk);
|
|
|
|
const bool status = svds_helper(U, S, V, X.get_ref(), k, tol, true);
|
|
|
|
if(status == false) { arma_debug_warn("svds(): decomposition failed"); }
|
|
|
|
return status;
|
|
}
|
|
|
|
|
|
|
|
//! find the k largest singular values of sparse matrix X
|
|
template<typename T1>
|
|
inline
|
|
bool
|
|
svds
|
|
(
|
|
Col<typename T1::pod_type >& S,
|
|
const SpBase<typename T1::elem_type,T1>& X,
|
|
const uword k,
|
|
const typename T1::pod_type tol = 0.0,
|
|
const typename arma_real_or_cx_only<typename T1::elem_type>::result* junk = 0
|
|
)
|
|
{
|
|
arma_extra_debug_sigprint();
|
|
arma_ignore(junk);
|
|
|
|
Mat<typename T1::elem_type> U;
|
|
Mat<typename T1::elem_type> V;
|
|
|
|
const bool status = svds_helper(U, S, V, X.get_ref(), k, tol, false);
|
|
|
|
if(status == false) { arma_debug_warn("svds(): decomposition failed"); }
|
|
|
|
return status;
|
|
}
|
|
|
|
|
|
|
|
//! find the k largest singular values of sparse matrix X
|
|
template<typename T1>
|
|
inline
|
|
Col<typename T1::pod_type>
|
|
svds
|
|
(
|
|
const SpBase<typename T1::elem_type,T1>& X,
|
|
const uword k,
|
|
const typename T1::pod_type tol = 0.0,
|
|
const typename arma_real_or_cx_only<typename T1::elem_type>::result* junk = 0
|
|
)
|
|
{
|
|
arma_extra_debug_sigprint();
|
|
arma_ignore(junk);
|
|
|
|
Col<typename T1::pod_type> S;
|
|
|
|
Mat<typename T1::elem_type> U;
|
|
Mat<typename T1::elem_type> V;
|
|
|
|
const bool status = svds_helper(U, S, V, X.get_ref(), k, tol, false);
|
|
|
|
if(status == false) { arma_bad("svds(): decomposition failed"); }
|
|
|
|
return S;
|
|
}
|
|
|
|
|
|
|
|
//! @}
|