254 lines
		
	
	
		
			4.7 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			254 lines
		
	
	
		
			4.7 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| // Copyright (C) 2008-2015 National ICT Australia (NICTA)
 | |
| // 
 | |
| // This Source Code Form is subject to the terms of the Mozilla Public
 | |
| // License, v. 2.0. If a copy of the MPL was not distributed with this
 | |
| // file, You can obtain one at http://mozilla.org/MPL/2.0/.
 | |
| // -------------------------------------------------------------------
 | |
| // 
 | |
| // Written by Conrad Sanderson - http://conradsanderson.id.au
 | |
| // Written by Ryan Curtin
 | |
| 
 | |
| 
 | |
| //! \addtogroup fn_trace
 | |
| //! @{
 | |
| 
 | |
| 
 | |
| template<typename T1>
 | |
| arma_hot
 | |
| arma_warn_unused
 | |
| inline
 | |
| typename enable_if2<is_arma_type<T1>::value, typename T1::elem_type>::result
 | |
| trace(const T1& X)
 | |
|   {
 | |
|   arma_extra_debug_sigprint();
 | |
|   
 | |
|   typedef typename T1::elem_type eT;
 | |
|   
 | |
|   const Proxy<T1> A(X);
 | |
|   
 | |
|   const uword N = (std::min)(A.get_n_rows(), A.get_n_cols());
 | |
|   
 | |
|   eT val1 = eT(0);
 | |
|   eT val2 = eT(0);
 | |
|   
 | |
|   uword i,j;
 | |
|   for(i=0, j=1; j<N; i+=2, j+=2)
 | |
|     {
 | |
|     val1 += A.at(i,i);
 | |
|     val2 += A.at(j,j);
 | |
|     }
 | |
|   
 | |
|   if(i < N)
 | |
|     {
 | |
|     val1 += A.at(i,i);
 | |
|     }
 | |
|   
 | |
|   return val1 + val2;
 | |
|   }
 | |
| 
 | |
| 
 | |
| 
 | |
| template<typename T1>
 | |
| arma_hot
 | |
| arma_warn_unused
 | |
| inline
 | |
| typename T1::elem_type
 | |
| trace(const Op<T1, op_diagmat>& X)
 | |
|   {
 | |
|   arma_extra_debug_sigprint();
 | |
|   
 | |
|   typedef typename T1::elem_type eT;
 | |
|   
 | |
|   const diagmat_proxy<T1> A(X.m);
 | |
|   
 | |
|   const uword N = (std::min)(A.n_rows, A.n_cols);
 | |
|   
 | |
|   eT val = eT(0);
 | |
|   
 | |
|   for(uword i=0; i<N; ++i)
 | |
|     {
 | |
|     val += A[i];
 | |
|     }
 | |
|   
 | |
|   return val;
 | |
|   }
 | |
| 
 | |
| 
 | |
| 
 | |
| template<typename T1, typename T2>
 | |
| arma_hot
 | |
| inline
 | |
| typename T1::elem_type
 | |
| trace_mul_unwrap(const Proxy<T1>& PA, const T2& XB)
 | |
|   {
 | |
|   arma_extra_debug_sigprint();
 | |
|   
 | |
|   typedef typename T1::elem_type eT;
 | |
|   
 | |
|   const unwrap<T2> tmpB(XB);
 | |
|   
 | |
|   const Mat<eT>& B = tmpB.M;
 | |
| 
 | |
|   const uword A_n_rows = PA.get_n_rows();
 | |
|   const uword A_n_cols = PA.get_n_cols();
 | |
|   
 | |
|   const uword B_n_rows = B.n_rows;
 | |
|   const uword B_n_cols = B.n_cols;
 | |
|   
 | |
|   arma_debug_assert_mul_size(A_n_rows, A_n_cols, B_n_rows, B_n_cols, "matrix multiplication");
 | |
|   
 | |
|   const uword N = (std::min)(A_n_rows, B_n_cols);
 | |
|   
 | |
|   eT val = eT(0);
 | |
|   
 | |
|   for(uword k=0; k < N; ++k)
 | |
|     {
 | |
|     const eT* B_colptr = B.colptr(k);
 | |
|     
 | |
|     eT acc1 = eT(0);
 | |
|     eT acc2 = eT(0);
 | |
|     
 | |
|     uword j;
 | |
|     
 | |
|     for(j=1; j < A_n_cols; j+=2)
 | |
|       {
 | |
|       const uword i = (j-1);
 | |
|       
 | |
|       const eT tmp_i = B_colptr[i];
 | |
|       const eT tmp_j = B_colptr[j];
 | |
|       
 | |
|       acc1 += PA.at(k, i) * tmp_i;
 | |
|       acc2 += PA.at(k, j) * tmp_j;
 | |
|       }
 | |
|     
 | |
|     const uword i = (j-1);
 | |
|     
 | |
|     if(i < A_n_cols)
 | |
|       {
 | |
|       acc1 += PA.at(k, i) * B_colptr[i];
 | |
|       }
 | |
|     
 | |
|     val += (acc1 + acc2);
 | |
|     }
 | |
|   
 | |
|   return val;
 | |
|   }
 | |
| 
 | |
| 
 | |
| 
 | |
| //! speedup for trace(A*B), where the result of A*B is a square sized matrix
 | |
| template<typename T1, typename T2>
 | |
| arma_hot
 | |
| inline
 | |
| typename T1::elem_type
 | |
| trace_mul_proxy(const Proxy<T1>& PA, const T2& XB)
 | |
|   {
 | |
|   arma_extra_debug_sigprint();
 | |
|   
 | |
|   typedef typename T1::elem_type eT;
 | |
|   
 | |
|   const Proxy<T2> PB(XB);
 | |
|   
 | |
|   if(is_Mat<typename Proxy<T2>::stored_type>::value)
 | |
|     {
 | |
|     return trace_mul_unwrap(PA, PB.Q);
 | |
|     }
 | |
|   
 | |
|   const uword A_n_rows = PA.get_n_rows();
 | |
|   const uword A_n_cols = PA.get_n_cols();
 | |
|   
 | |
|   const uword B_n_rows = PB.get_n_rows();
 | |
|   const uword B_n_cols = PB.get_n_cols();
 | |
|   
 | |
|   arma_debug_assert_mul_size(A_n_rows, A_n_cols, B_n_rows, B_n_cols, "matrix multiplication");
 | |
|   
 | |
|   const uword N = (std::min)(A_n_rows, B_n_cols);
 | |
|   
 | |
|   eT val = eT(0);
 | |
|   
 | |
|   for(uword k=0; k < N; ++k)
 | |
|     {
 | |
|     eT acc1 = eT(0);
 | |
|     eT acc2 = eT(0);
 | |
|     
 | |
|     uword j;
 | |
|     
 | |
|     for(j=1; j < A_n_cols; j+=2)
 | |
|       {
 | |
|       const uword i = (j-1);
 | |
|       
 | |
|       const eT tmp_i = PB.at(i, k);
 | |
|       const eT tmp_j = PB.at(j, k);
 | |
|       
 | |
|       acc1 += PA.at(k, i) * tmp_i;
 | |
|       acc2 += PA.at(k, j) * tmp_j;
 | |
|       }
 | |
|     
 | |
|     const uword i = (j-1);
 | |
|     
 | |
|     if(i < A_n_cols)
 | |
|       {
 | |
|       acc1 += PA.at(k, i) * PB.at(i, k);
 | |
|       }
 | |
|     
 | |
|     val += (acc1 + acc2);
 | |
|     }
 | |
|   
 | |
|   return val;
 | |
|   }
 | |
| 
 | |
| 
 | |
| 
 | |
| //! speedup for trace(A*B), where the result of A*B is a square sized matrix
 | |
| template<typename T1, typename T2>
 | |
| arma_hot
 | |
| arma_warn_unused
 | |
| inline
 | |
| typename T1::elem_type
 | |
| trace(const Glue<T1, T2, glue_times>& X)
 | |
|   {
 | |
|   arma_extra_debug_sigprint();
 | |
|   
 | |
|   const Proxy<T1> PA(X.A);
 | |
|   
 | |
|   return (is_Mat<T2>::value) ? trace_mul_unwrap(PA, X.B) : trace_mul_proxy(PA, X.B);
 | |
|   }
 | |
| 
 | |
| 
 | |
| 
 | |
| //! trace of sparse object
 | |
| template<typename T1>
 | |
| arma_hot
 | |
| arma_warn_unused
 | |
| inline
 | |
| typename enable_if2<is_arma_sparse_type<T1>::value, typename T1::elem_type>::result
 | |
| trace(const T1& x)
 | |
|   {
 | |
|   arma_extra_debug_sigprint();
 | |
|   
 | |
|   const SpProxy<T1> p(x);
 | |
|   
 | |
|   typedef typename T1::elem_type eT;
 | |
|   
 | |
|   eT result = eT(0);
 | |
|   
 | |
|   typename SpProxy<T1>::const_iterator_type it     = p.begin();
 | |
|   typename SpProxy<T1>::const_iterator_type it_end = p.end();
 | |
|   
 | |
|   while(it != it_end)
 | |
|     {
 | |
|     if(it.row() == it.col())
 | |
|       {
 | |
|       result += (*it);
 | |
|       }
 | |
|     
 | |
|     ++it;
 | |
|     }
 | |
|   
 | |
|   return result;
 | |
|   }
 | |
| 
 | |
| 
 | |
| 
 | |
| //! @}
 | 
