254 lines
4.7 KiB
C++
254 lines
4.7 KiB
C++
// Copyright (C) 2008-2015 National ICT Australia (NICTA)
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
// -------------------------------------------------------------------
//
// Written by Conrad Sanderson - http://conradsanderson.id.au
// Written by Ryan Curtin



//! \addtogroup fn_trace
//! @{
|
|
|
|
|
|
template<typename T1>
|
|
arma_hot
|
|
arma_warn_unused
|
|
inline
|
|
typename enable_if2<is_arma_type<T1>::value, typename T1::elem_type>::result
|
|
trace(const T1& X)
|
|
{
|
|
arma_extra_debug_sigprint();
|
|
|
|
typedef typename T1::elem_type eT;
|
|
|
|
const Proxy<T1> A(X);
|
|
|
|
const uword N = (std::min)(A.get_n_rows(), A.get_n_cols());
|
|
|
|
eT val1 = eT(0);
|
|
eT val2 = eT(0);
|
|
|
|
uword i,j;
|
|
for(i=0, j=1; j<N; i+=2, j+=2)
|
|
{
|
|
val1 += A.at(i,i);
|
|
val2 += A.at(j,j);
|
|
}
|
|
|
|
if(i < N)
|
|
{
|
|
val1 += A.at(i,i);
|
|
}
|
|
|
|
return val1 + val2;
|
|
}
|
|
|
|
|
|
|
|
template<typename T1>
|
|
arma_hot
|
|
arma_warn_unused
|
|
inline
|
|
typename T1::elem_type
|
|
trace(const Op<T1, op_diagmat>& X)
|
|
{
|
|
arma_extra_debug_sigprint();
|
|
|
|
typedef typename T1::elem_type eT;
|
|
|
|
const diagmat_proxy<T1> A(X.m);
|
|
|
|
const uword N = (std::min)(A.n_rows, A.n_cols);
|
|
|
|
eT val = eT(0);
|
|
|
|
for(uword i=0; i<N; ++i)
|
|
{
|
|
val += A[i];
|
|
}
|
|
|
|
return val;
|
|
}
|
|
|
|
|
|
|
|
template<typename T1, typename T2>
|
|
arma_hot
|
|
inline
|
|
typename T1::elem_type
|
|
trace_mul_unwrap(const Proxy<T1>& PA, const T2& XB)
|
|
{
|
|
arma_extra_debug_sigprint();
|
|
|
|
typedef typename T1::elem_type eT;
|
|
|
|
const unwrap<T2> tmpB(XB);
|
|
|
|
const Mat<eT>& B = tmpB.M;
|
|
|
|
const uword A_n_rows = PA.get_n_rows();
|
|
const uword A_n_cols = PA.get_n_cols();
|
|
|
|
const uword B_n_rows = B.n_rows;
|
|
const uword B_n_cols = B.n_cols;
|
|
|
|
arma_debug_assert_mul_size(A_n_rows, A_n_cols, B_n_rows, B_n_cols, "matrix multiplication");
|
|
|
|
const uword N = (std::min)(A_n_rows, B_n_cols);
|
|
|
|
eT val = eT(0);
|
|
|
|
for(uword k=0; k < N; ++k)
|
|
{
|
|
const eT* B_colptr = B.colptr(k);
|
|
|
|
eT acc1 = eT(0);
|
|
eT acc2 = eT(0);
|
|
|
|
uword j;
|
|
|
|
for(j=1; j < A_n_cols; j+=2)
|
|
{
|
|
const uword i = (j-1);
|
|
|
|
const eT tmp_i = B_colptr[i];
|
|
const eT tmp_j = B_colptr[j];
|
|
|
|
acc1 += PA.at(k, i) * tmp_i;
|
|
acc2 += PA.at(k, j) * tmp_j;
|
|
}
|
|
|
|
const uword i = (j-1);
|
|
|
|
if(i < A_n_cols)
|
|
{
|
|
acc1 += PA.at(k, i) * B_colptr[i];
|
|
}
|
|
|
|
val += (acc1 + acc2);
|
|
}
|
|
|
|
return val;
|
|
}
|
|
|
|
|
|
|
|
//! speedup for trace(A*B), where the result of A*B is a square sized matrix
|
|
template<typename T1, typename T2>
|
|
arma_hot
|
|
inline
|
|
typename T1::elem_type
|
|
trace_mul_proxy(const Proxy<T1>& PA, const T2& XB)
|
|
{
|
|
arma_extra_debug_sigprint();
|
|
|
|
typedef typename T1::elem_type eT;
|
|
|
|
const Proxy<T2> PB(XB);
|
|
|
|
if(is_Mat<typename Proxy<T2>::stored_type>::value)
|
|
{
|
|
return trace_mul_unwrap(PA, PB.Q);
|
|
}
|
|
|
|
const uword A_n_rows = PA.get_n_rows();
|
|
const uword A_n_cols = PA.get_n_cols();
|
|
|
|
const uword B_n_rows = PB.get_n_rows();
|
|
const uword B_n_cols = PB.get_n_cols();
|
|
|
|
arma_debug_assert_mul_size(A_n_rows, A_n_cols, B_n_rows, B_n_cols, "matrix multiplication");
|
|
|
|
const uword N = (std::min)(A_n_rows, B_n_cols);
|
|
|
|
eT val = eT(0);
|
|
|
|
for(uword k=0; k < N; ++k)
|
|
{
|
|
eT acc1 = eT(0);
|
|
eT acc2 = eT(0);
|
|
|
|
uword j;
|
|
|
|
for(j=1; j < A_n_cols; j+=2)
|
|
{
|
|
const uword i = (j-1);
|
|
|
|
const eT tmp_i = PB.at(i, k);
|
|
const eT tmp_j = PB.at(j, k);
|
|
|
|
acc1 += PA.at(k, i) * tmp_i;
|
|
acc2 += PA.at(k, j) * tmp_j;
|
|
}
|
|
|
|
const uword i = (j-1);
|
|
|
|
if(i < A_n_cols)
|
|
{
|
|
acc1 += PA.at(k, i) * PB.at(i, k);
|
|
}
|
|
|
|
val += (acc1 + acc2);
|
|
}
|
|
|
|
return val;
|
|
}
|
|
|
|
|
|
|
|
//! speedup for trace(A*B), where the result of A*B is a square sized matrix
|
|
template<typename T1, typename T2>
|
|
arma_hot
|
|
arma_warn_unused
|
|
inline
|
|
typename T1::elem_type
|
|
trace(const Glue<T1, T2, glue_times>& X)
|
|
{
|
|
arma_extra_debug_sigprint();
|
|
|
|
const Proxy<T1> PA(X.A);
|
|
|
|
return (is_Mat<T2>::value) ? trace_mul_unwrap(PA, X.B) : trace_mul_proxy(PA, X.B);
|
|
}
|
|
|
|
|
|
|
|
//! trace of sparse object
|
|
template<typename T1>
|
|
arma_hot
|
|
arma_warn_unused
|
|
inline
|
|
typename enable_if2<is_arma_sparse_type<T1>::value, typename T1::elem_type>::result
|
|
trace(const T1& x)
|
|
{
|
|
arma_extra_debug_sigprint();
|
|
|
|
const SpProxy<T1> p(x);
|
|
|
|
typedef typename T1::elem_type eT;
|
|
|
|
eT result = eT(0);
|
|
|
|
typename SpProxy<T1>::const_iterator_type it = p.begin();
|
|
typename SpProxy<T1>::const_iterator_type it_end = p.end();
|
|
|
|
while(it != it_end)
|
|
{
|
|
if(it.row() == it.col())
|
|
{
|
|
result += (*it);
|
|
}
|
|
|
|
++it;
|
|
}
|
|
|
|
return result;
|
|
}
|
|
|
|
|
|
|
|
//! @}
|