// Copyright (C) 2008-2015 National ICT Australia (NICTA) // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. // ------------------------------------------------------------------- // // Written by Conrad Sanderson - http://conradsanderson.id.au // Written by Ryan Curtin //! \addtogroup fn_trace //! @{ template arma_hot arma_warn_unused inline typename enable_if2::value, typename T1::elem_type>::result trace(const T1& X) { arma_extra_debug_sigprint(); typedef typename T1::elem_type eT; const Proxy A(X); const uword N = (std::min)(A.get_n_rows(), A.get_n_cols()); eT val1 = eT(0); eT val2 = eT(0); uword i,j; for(i=0, j=1; j arma_hot arma_warn_unused inline typename T1::elem_type trace(const Op& X) { arma_extra_debug_sigprint(); typedef typename T1::elem_type eT; const diagmat_proxy A(X.m); const uword N = (std::min)(A.n_rows, A.n_cols); eT val = eT(0); for(uword i=0; i arma_hot inline typename T1::elem_type trace_mul_unwrap(const Proxy& PA, const T2& XB) { arma_extra_debug_sigprint(); typedef typename T1::elem_type eT; const unwrap tmpB(XB); const Mat& B = tmpB.M; const uword A_n_rows = PA.get_n_rows(); const uword A_n_cols = PA.get_n_cols(); const uword B_n_rows = B.n_rows; const uword B_n_cols = B.n_cols; arma_debug_assert_mul_size(A_n_rows, A_n_cols, B_n_rows, B_n_cols, "matrix multiplication"); const uword N = (std::min)(A_n_rows, B_n_cols); eT val = eT(0); for(uword k=0; k < N; ++k) { const eT* B_colptr = B.colptr(k); eT acc1 = eT(0); eT acc2 = eT(0); uword j; for(j=1; j < A_n_cols; j+=2) { const uword i = (j-1); const eT tmp_i = B_colptr[i]; const eT tmp_j = B_colptr[j]; acc1 += PA.at(k, i) * tmp_i; acc2 += PA.at(k, j) * tmp_j; } const uword i = (j-1); if(i < A_n_cols) { acc1 += PA.at(k, i) * B_colptr[i]; } val += (acc1 + acc2); } return val; } //! speedup for trace(A*B), where the result of A*B is a square sized matrix template arma_hot inline typename T1::elem_type trace_mul_proxy(const Proxy& PA, const T2& XB) { arma_extra_debug_sigprint(); typedef typename T1::elem_type eT; const Proxy PB(XB); if(is_Mat::stored_type>::value) { return trace_mul_unwrap(PA, PB.Q); } const uword A_n_rows = PA.get_n_rows(); const uword A_n_cols = PA.get_n_cols(); const uword B_n_rows = PB.get_n_rows(); const uword B_n_cols = PB.get_n_cols(); arma_debug_assert_mul_size(A_n_rows, A_n_cols, B_n_rows, B_n_cols, "matrix multiplication"); const uword N = (std::min)(A_n_rows, B_n_cols); eT val = eT(0); for(uword k=0; k < N; ++k) { eT acc1 = eT(0); eT acc2 = eT(0); uword j; for(j=1; j < A_n_cols; j+=2) { const uword i = (j-1); const eT tmp_i = PB.at(i, k); const eT tmp_j = PB.at(j, k); acc1 += PA.at(k, i) * tmp_i; acc2 += PA.at(k, j) * tmp_j; } const uword i = (j-1); if(i < A_n_cols) { acc1 += PA.at(k, i) * PB.at(i, k); } val += (acc1 + acc2); } return val; } //! speedup for trace(A*B), where the result of A*B is a square sized matrix template arma_hot arma_warn_unused inline typename T1::elem_type trace(const Glue& X) { arma_extra_debug_sigprint(); const Proxy PA(X.A); return (is_Mat::value) ? trace_mul_unwrap(PA, X.B) : trace_mul_proxy(PA, X.B); } //! trace of sparse object template arma_hot arma_warn_unused inline typename enable_if2::value, typename T1::elem_type>::result trace(const T1& x) { arma_extra_debug_sigprint(); const SpProxy p(x); typedef typename T1::elem_type eT; eT result = eT(0); typename SpProxy::const_iterator_type it = p.begin(); typename SpProxy::const_iterator_type it_end = p.end(); while(it != it_end) { if(it.row() == it.col()) { result += (*it); } ++it; } return result; } //! @}