// Copyright (C) 2015 National ICT Australia (NICTA) // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. // ------------------------------------------------------------------- // // Written by Conrad Sanderson - http://conradsanderson.id.au //! \addtogroup fn_svds //! @{ template inline bool svds_helper ( Mat& U, Col& S, Mat& V, const SpBase& X, const uword k, const typename T1::pod_type tol, const bool calc_UV, const typename arma_real_only::result* junk = 0 ) { arma_extra_debug_sigprint(); arma_ignore(junk); typedef typename T1::elem_type eT; typedef typename T1::pod_type T; if(arma_config::arpack == false) { arma_stop("svds(): use of ARPACK must be enabled"); return false; } arma_debug_check ( ( ((void*)(&U) == (void*)(&S)) || (&U == &V) || ((void*)(&S) == (void*)(&V)) ), "svds(): two or more output objects are the same object" ); arma_debug_check( (tol < T(0)), "svds(): tol must be >= 0" ); const unwrap_spmat tmp(X.get_ref()); const SpMat& A = tmp.M; const uword kk = (std::min)( (std::min)(A.n_rows, A.n_cols), k ); const T A_max = (A.n_nonzero > 0) ? T(max(abs(Col(const_cast(A.values), A.n_nonzero, false)))) : T(0); if(A_max == T(0)) { // TODO: use reset instead ? S.zeros(kk); if(calc_UV) { U.eye(A.n_rows, kk); V.eye(A.n_cols, kk); } } else { SpMat C( (A.n_rows + A.n_cols), (A.n_rows + A.n_cols) ); SpMat B = A / A_max; SpMat Bt = B.t(); C(0, A.n_rows, size(B) ) = B; C(A.n_rows, 0, size(Bt)) = Bt; Bt.reset(); B.reset(); Col eigval; Mat eigvec; const bool status = sp_auxlib::eigs_sym(eigval, eigvec, C, kk, "la", (tol / Datum::sqrt2)); if(status == false) { U.reset(); S.reset(); V.reset(); return false; } const T A_norm = max(eigval); const T tol2 = tol / Datum::sqrt2 * A_norm; uvec indices = find(eigval > tol2); if(indices.n_elem > kk) { indices = indices.subvec(0,kk-1); } else if(indices.n_elem < kk) { const uvec indices2 = find(abs(eigval) <= tol2); const uword N_extra = (std::min)( indices2.n_elem, (kk - indices.n_elem) ); if(N_extra > 0) { indices = join_cols(indices, indices2.subvec(0,N_extra-1)); } } const uvec sorted_indices = sort_index(eigval, "descend"); S = eigval.elem(sorted_indices); S *= A_max; if(calc_UV) { uvec U_row_indices(A.n_rows); for(uword i=0; i < A.n_rows; ++i) { U_row_indices[i] = i; } uvec V_row_indices(A.n_cols); for(uword i=0; i < A.n_cols; ++i) { V_row_indices[i] = i + A.n_rows; } U = Datum::sqrt2 * eigvec(U_row_indices, sorted_indices); V = Datum::sqrt2 * eigvec(V_row_indices, sorted_indices); } } if(S.n_elem < k) { arma_debug_warn("svds(): found fewer singular values than specified"); } return true; } template inline bool svds_helper ( Mat& U, Col& S, Mat& V, const SpBase& X, const uword k, const typename T1::pod_type tol, const bool calc_UV, const typename arma_cx_only::result* junk = 0 ) { arma_extra_debug_sigprint(); arma_ignore(junk); typedef typename T1::elem_type eT; typedef typename T1::pod_type T; if(arma_config::arpack == false) { arma_stop("svds(): use of ARPACK must be enabled"); return false; } arma_debug_check ( ( ((void*)(&U) == (void*)(&S)) || (&U == &V) || ((void*)(&S) == (void*)(&V)) ), "svds(): two or more output objects are the same object" ); arma_debug_check( (tol < T(0)), "svds(): tol must be >= 0" ); const unwrap_spmat tmp(X.get_ref()); const SpMat& A = tmp.M; const uword kk = (std::min)( (std::min)(A.n_rows, A.n_cols), k ); const T A_max = (A.n_nonzero > 0) ? T(max(abs(Col(const_cast(A.values), A.n_nonzero, false)))) : T(0); if(A_max == T(0)) { // TODO: use reset instead ? S.zeros(kk); if(calc_UV) { U.eye(A.n_rows, kk); V.eye(A.n_cols, kk); } } else { SpMat C( (A.n_rows + A.n_cols), (A.n_rows + A.n_cols) ); SpMat B = A / A_max; SpMat Bt = B.t(); C(0, A.n_rows, size(B) ) = B; C(A.n_rows, 0, size(Bt)) = Bt; Bt.reset(); B.reset(); Col eigval_tmp; Mat eigvec; const bool status = sp_auxlib::eigs_gen(eigval_tmp, eigvec, C, kk, "lr", (tol / Datum::sqrt2)); if(status == false) { U.reset(); S.reset(); V.reset(); arma_debug_warn("svds(): decomposition failed"); return false; } const Col eigval = real(eigval_tmp); const T A_norm = max(eigval); const T tol2 = tol / Datum::sqrt2 * A_norm; uvec indices = find(eigval > tol2); if(indices.n_elem > kk) { indices = indices.subvec(0,kk-1); } else if(indices.n_elem < kk) { const uvec indices2 = find(abs(eigval) <= tol2); const uword N_extra = (std::min)( indices2.n_elem, (kk - indices.n_elem) ); if(N_extra > 0) { indices = join_cols(indices, indices2.subvec(0,N_extra-1)); } } const uvec sorted_indices = sort_index(eigval, "descend"); S = eigval.elem(sorted_indices); S *= A_max; if(calc_UV) { uvec U_row_indices(A.n_rows); for(uword i=0; i < A.n_rows; ++i) { U_row_indices[i] = i; } uvec V_row_indices(A.n_cols); for(uword i=0; i < A.n_cols; ++i) { V_row_indices[i] = i + A.n_rows; } U = Datum::sqrt2 * eigvec(U_row_indices, sorted_indices); V = Datum::sqrt2 * eigvec(V_row_indices, sorted_indices); } } if(S.n_elem < k) { arma_debug_warn("svds(): found fewer singular values than specified"); } return true; } //! find the k largest singular values and corresponding singular vectors of sparse matrix X template inline bool svds ( Mat& U, Col& S, Mat& V, const SpBase& X, const uword k, const typename T1::pod_type tol = 0.0, const typename arma_real_or_cx_only::result* junk = 0 ) { arma_extra_debug_sigprint(); arma_ignore(junk); const bool status = svds_helper(U, S, V, X.get_ref(), k, tol, true); if(status == false) { arma_debug_warn("svds(): decomposition failed"); } return status; } //! find the k largest singular values of sparse matrix X template inline bool svds ( Col& S, const SpBase& X, const uword k, const typename T1::pod_type tol = 0.0, const typename arma_real_or_cx_only::result* junk = 0 ) { arma_extra_debug_sigprint(); arma_ignore(junk); Mat U; Mat V; const bool status = svds_helper(U, S, V, X.get_ref(), k, tol, false); if(status == false) { arma_debug_warn("svds(): decomposition failed"); } return status; } //! find the k largest singular values of sparse matrix X template inline Col svds ( const SpBase& X, const uword k, const typename T1::pod_type tol = 0.0, const typename arma_real_or_cx_only::result* junk = 0 ) { arma_extra_debug_sigprint(); arma_ignore(junk); Col S; Mat U; Mat V; const bool status = svds_helper(U, S, V, X.get_ref(), k, tol, false); if(status == false) { arma_bad("svds(): decomposition failed"); } return S; } //! @}