// Copyright (C) 2014-2015 National ICT Australia (NICTA) // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. // ------------------------------------------------------------------- // // Written by Conrad Sanderson - http://conradsanderson.id.au //! \addtogroup gmm_diag //! @{ namespace gmm_priv { template inline gmm_diag::~gmm_diag() { arma_extra_debug_sigprint_this(this); arma_type_check(( (is_same_type::value == false) && (is_same_type::value == false) )); } template inline gmm_diag::gmm_diag() { arma_extra_debug_sigprint_this(this); } template inline gmm_diag::gmm_diag(const gmm_diag& x) { arma_extra_debug_sigprint_this(this); init(x); } template inline const gmm_diag& gmm_diag::operator=(const gmm_diag& x) { arma_extra_debug_sigprint(); init(x); return *this; } template inline gmm_diag::gmm_diag(const uword in_n_dims, const uword in_n_gaus) { arma_extra_debug_sigprint_this(this); init(in_n_dims, in_n_gaus); } template inline void gmm_diag::reset() { arma_extra_debug_sigprint(); mah_aux.reset(); init(0, 0); } template inline void gmm_diag::reset(const uword in_n_dims, const uword in_n_gaus) { arma_extra_debug_sigprint(); mah_aux.reset(); init(in_n_dims, in_n_gaus); } template template inline void gmm_diag::set_params(const Base& in_means_expr, const Base& in_dcovs_expr, const Base& in_hefts_expr) { arma_extra_debug_sigprint(); const unwrap tmp1(in_means_expr.get_ref()); const unwrap tmp2(in_dcovs_expr.get_ref()); const unwrap tmp3(in_hefts_expr.get_ref()); const Mat& in_means = tmp1.M; const Mat& in_dcovs = tmp2.M; const Mat& in_hefts = tmp3.M; arma_debug_check ( (size(in_means) != size(in_dcovs)) || (in_hefts.n_cols != in_means.n_cols) || (in_hefts.n_rows != 1), "gmm_diag::set_params(): given parameters have inconsistent and/or wrong sizes" ); arma_debug_check( (in_means.is_finite() == false), "gmm_diag::set_params(): given means have non-finite values" ); arma_debug_check( (in_dcovs.is_finite() == false), "gmm_diag::set_params(): given dcovs have non-finite values" ); arma_debug_check( (in_hefts.is_finite() == false), "gmm_diag::set_params(): given hefts have non-finite values" ); arma_debug_check( (any(vectorise(in_dcovs) <= eT(0))), "gmm_diag::set_params(): given dcovs have negative or zero values" ); arma_debug_check( (any(vectorise(in_hefts) < eT(0))), "gmm_diag::set_params(): given hefts have negative values" ); const eT s = accu(in_hefts); arma_debug_check( ((s < (eT(1) - Datum::eps)) || (s > (eT(1) + Datum::eps))), "gmm_diag::set_params(): sum of given hefts is not 1" ); access::rw(means) = in_means; access::rw(dcovs) = in_dcovs; access::rw(hefts) = in_hefts; init_constants(); } template template inline void gmm_diag::set_means(const Base& in_means_expr) { arma_extra_debug_sigprint(); const unwrap tmp(in_means_expr.get_ref()); const Mat& in_means = tmp.M; arma_debug_check( (size(in_means) != size(means)), "gmm_diag::set_means(): given means have incompatible size" ); arma_debug_check( (in_means.is_finite() == false), "gmm_diag::set_means(): given means have non-finite values" ); access::rw(means) = in_means; } template template inline void gmm_diag::set_dcovs(const Base& in_dcovs_expr) { arma_extra_debug_sigprint(); const unwrap tmp(in_dcovs_expr.get_ref()); const Mat& in_dcovs = tmp.M; arma_debug_check( (size(in_dcovs) != size(dcovs)), "gmm_diag::set_dcovs(): given dcovs have incompatible size" ); arma_debug_check( (in_dcovs.is_finite() == false), "gmm_diag::set_dcovs(): given dcovs have non-finite values" ); arma_debug_check( (any(vectorise(in_dcovs) <= eT(0))), "gmm_diag::set_dcovs(): given dcovs have negative or zero values" ); access::rw(dcovs) = in_dcovs; init_constants(); } template template inline void gmm_diag::set_hefts(const Base& in_hefts_expr) { arma_extra_debug_sigprint(); const unwrap tmp(in_hefts_expr.get_ref()); const Mat& in_hefts = tmp.M; arma_debug_check( (size(in_hefts) != size(hefts)), "gmm_diag::set_hefts(): given hefts have incompatible size" ); arma_debug_check( (in_hefts.is_finite() == false), "gmm_diag::set_hefts(): given hefts have non-finite values" ); arma_debug_check( (any(vectorise(in_hefts) < eT(0))), "gmm_diag::set_hefts(): given hefts have negative values" ); const eT s = accu(in_hefts); arma_debug_check( ((s < (eT(1) - Datum::eps)) || (s > (eT(1) + Datum::eps))), "gmm_diag::set_hefts(): sum of given hefts is not 1" ); // make sure all hefts are positive and non-zero const eT* in_hefts_mem = in_hefts.memptr(); eT* hefts_mem = access::rw(hefts).memptr(); for(uword i=0; i < hefts.n_elem; ++i) { hefts_mem[i] = (std::max)( in_hefts_mem[i], std::numeric_limits::min() ); } access::rw(hefts) /= accu(hefts); log_hefts = log(hefts); } template inline uword gmm_diag::n_dims() const { return means.n_rows; } template inline uword gmm_diag::n_gaus() const { return means.n_cols; } template inline bool gmm_diag::load(const std::string name) { arma_extra_debug_sigprint(); Cube Q; bool status = Q.load(name, arma_binary); if( (status == false) || (Q.n_slices != 2) ) { reset(); return false; } if( (Q.n_rows < 2) || (Q.n_cols < 1) ) { reset(); return true; } access::rw(hefts) = Q.slice(0).row(0); access::rw(means) = Q.slice(0).submat(1, 0, Q.n_rows-1, Q.n_cols-1); access::rw(dcovs) = Q.slice(1).submat(1, 0, Q.n_rows-1, Q.n_cols-1); init_constants(); return true; } template inline bool gmm_diag::save(const std::string name) const { arma_extra_debug_sigprint(); Cube Q(means.n_rows + 1, means.n_cols, 2); if(Q.n_elem > 0) { Q.slice(0).row(0) = hefts; Q.slice(1).row(0).zeros(); // reserved for future use Q.slice(0).submat(1, 0, size(means)) = means; Q.slice(1).submat(1, 0, size(dcovs)) = dcovs; } const bool status = Q.save(name, arma_binary); return status; } template inline Col gmm_diag::generate() const { arma_extra_debug_sigprint(); const uword N_dims = means.n_rows; const uword N_gaus = means.n_cols; Col out( (N_gaus > 0) ? N_dims : uword(0) ); if(N_gaus > 0) { const double val = randu(); double csum = double(0); uword gaus_id = 0; for(uword j=0; j < N_gaus; ++j) { csum += hefts[j]; if(val <= csum) { gaus_id = j; break; } } out = randn< Col >(N_dims); out %= sqrt(dcovs.col(gaus_id)); out += means.col(gaus_id); } return out; } template inline Mat gmm_diag::generate(const uword N_vec) const { arma_extra_debug_sigprint(); const uword N_dims = means.n_rows; const uword N_gaus = means.n_cols; Mat out( ( (N_gaus > 0) ? N_dims : uword(0) ), N_vec ); if(N_gaus > 0) { const eT* hefts_mem = hefts.memptr(); for(uword i=0; i < N_vec; ++i) { const double val = randu(); double csum = double(0); uword gaus_id = 0; for(uword j=0; j < N_gaus; ++j) { csum += hefts_mem[j]; if(val <= csum) { gaus_id = j; break; } } subview_col out_col = out.col(i); out_col = randn< Col >(N_dims); out_col %= sqrt(dcovs.col(gaus_id)); out_col += means.col(gaus_id); } } return out; } template template inline eT gmm_diag::log_p(const T1& expr, const gmm_empty_arg& junk1, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == true))>::result* junk2) const { arma_extra_debug_sigprint(); arma_ignore(junk1); arma_ignore(junk2); const quasi_unwrap tmp(expr); arma_debug_check( (tmp.M.n_rows != means.n_rows), "gmm_diag::log_p(): incompatible dimensions" ); return internal_scalar_log_p( tmp.M.memptr() ); } template template inline eT gmm_diag::log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == true))>::result* junk2) const { arma_extra_debug_sigprint(); arma_ignore(junk2); const quasi_unwrap tmp(expr); arma_debug_check( (tmp.M.n_rows != means.n_rows), "gmm_diag::log_p(): incompatible dimensions" ); arma_debug_check( (gaus_id >= means.n_cols), "gmm_diag::log_p(): specified gaussian is out of range" ); return internal_scalar_log_p( tmp.M.memptr(), gaus_id ); } template template inline Row gmm_diag::log_p(const T1& expr, const gmm_empty_arg& junk1, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == false))>::result* junk2) const { arma_extra_debug_sigprint(); arma_ignore(junk1); arma_ignore(junk2); if(is_subview::value) { const subview& X = reinterpret_cast< const subview& >(expr); return internal_vec_log_p(X); } else { const unwrap tmp(expr); const Mat& X = tmp.M; return internal_vec_log_p(X); } } template template inline Row gmm_diag::log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == false))>::result* junk2) const { arma_extra_debug_sigprint(); arma_ignore(junk2); if(is_subview::value) { const subview& X = reinterpret_cast< const subview& >(expr); return internal_vec_log_p(X, gaus_id); } else { const unwrap tmp(expr); const Mat& X = tmp.M; return internal_vec_log_p(X, gaus_id); } } template template inline eT gmm_diag::avg_log_p(const Base& expr) const { arma_extra_debug_sigprint(); if(is_subview::value) { const subview& X = reinterpret_cast< const subview& >( expr.get_ref() ); return internal_avg_log_p(X); } else { const unwrap tmp(expr.get_ref()); const Mat& X = tmp.M; return internal_avg_log_p(X); } } template template inline eT gmm_diag::avg_log_p(const Base& expr, const uword gaus_id) const { arma_extra_debug_sigprint(); if(is_subview::value) { const subview& X = reinterpret_cast< const subview& >( expr.get_ref() ); return internal_avg_log_p(X, gaus_id); } else { const unwrap tmp(expr.get_ref()); const Mat& X = tmp.M; return internal_avg_log_p(X, gaus_id); } } template template inline uword gmm_diag::assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == true))>::result* junk) const { arma_extra_debug_sigprint(); arma_ignore(junk); if(is_subview_col::value) { const subview_col& X = reinterpret_cast< const subview_col& >(expr); return internal_scalar_assign(X, dist); } else { const unwrap tmp(expr); const Mat& X = tmp.M; return internal_scalar_assign(X, dist); } } template template inline urowvec gmm_diag::assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == false))>::result* junk) const { arma_extra_debug_sigprint(); arma_ignore(junk); urowvec out; if(is_subview::value) { const subview& X = reinterpret_cast< const subview& >(expr); internal_vec_assign(out, X, dist); } else { const unwrap tmp(expr); const Mat& X = tmp.M; internal_vec_assign(out, X, dist); } return out; } template template inline urowvec gmm_diag::raw_hist(const Base& expr, const gmm_dist_mode& dist_mode) const { arma_extra_debug_sigprint(); const unwrap tmp(expr.get_ref()); const Mat& X = tmp.M; arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::raw_hist(): incompatible dimensions" ); arma_debug_check( ((dist_mode != eucl_dist) && (dist_mode != prob_dist)), "gmm_diag::raw_hist(): unsupported distance mode" ); urowvec hist; internal_raw_hist(hist, X, dist_mode); return hist; } template template inline Row gmm_diag::norm_hist(const Base& expr, const gmm_dist_mode& dist_mode) const { arma_extra_debug_sigprint(); const unwrap tmp(expr.get_ref()); const Mat& X = tmp.M; arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::norm_hist(): incompatible dimensions" ); arma_debug_check( ((dist_mode != eucl_dist) && (dist_mode != prob_dist)), "gmm_diag::norm_hist(): unsupported distance mode" ); urowvec hist; internal_raw_hist(hist, X, dist_mode); const uword hist_n_elem = hist.n_elem; const uword* hist_mem = hist.memptr(); eT acc = eT(0); for(uword i=0; i out(hist_n_elem); eT* out_mem = out.memptr(); for(uword i=0; i template inline bool gmm_diag::learn ( const Base& data, const uword N_gaus, const gmm_dist_mode& dist_mode, const gmm_seed_mode& seed_mode, const uword km_iter, const uword em_iter, const eT var_floor, const bool print_mode ) { arma_extra_debug_sigprint(); const bool dist_mode_ok = (dist_mode == eucl_dist) || (dist_mode == maha_dist); const bool seed_mode_ok = \ (seed_mode == keep_existing) || (seed_mode == static_subset) || (seed_mode == static_spread) || (seed_mode == random_subset) || (seed_mode == random_spread); arma_debug_check( (dist_mode_ok == false), "gmm_diag::learn(): dist_mode must be eucl_dist or maha_dist" ); arma_debug_check( (seed_mode_ok == false), "gmm_diag::learn(): unknown seed_mode" ); arma_debug_check( (var_floor < eT(0) ), "gmm_diag::learn(): variance floor is negative" ); const unwrap tmp_X(data.get_ref()); const Mat& X = tmp_X.M; if(X.is_empty() ) { arma_debug_warn("gmm_diag::learn(): given matrix is empty" ); return false; } if(X.is_finite() == false) { arma_debug_warn("gmm_diag::learn(): given matrix has non-finite values"); return false; } if(N_gaus == 0) { reset(); return true; } if(dist_mode == maha_dist) { mah_aux = var(X,1,1); const uword mah_aux_n_elem = mah_aux.n_elem; eT* mah_aux_mem = mah_aux.memptr(); for(uword i=0; i < mah_aux_n_elem; ++i) { const eT val = mah_aux_mem[i]; mah_aux_mem[i] = ((val != eT(0)) && arma_isfinite(val)) ? eT(1) / val : eT(1); } } // copy current model, in case of failure by k-means and/or EM const gmm_diag orig = (*this); // initial means if(seed_mode == keep_existing) { if(means.is_empty() ) { arma_debug_warn("gmm_diag::learn(): no existing means" ); return false; } if(X.n_rows != means.n_rows) { arma_debug_warn("gmm_diag::learn(): dimensionality mismatch"); return false; } // TODO: also check for number of vectors? } else { if(X.n_cols < N_gaus) { arma_debug_warn("gmm_diag::learn(): number of vectors is less than number of gaussians"); return false; } reset(X.n_rows, N_gaus); if(print_mode) { get_stream_err2() << "gmm_diag::learn(): generating initial means\n"; } if(dist_mode == eucl_dist) { generate_initial_means<1>(X, seed_mode); } else if(dist_mode == maha_dist) { generate_initial_means<2>(X, seed_mode); } } // k-means if(km_iter > 0) { const arma_ostream_state stream_state(get_stream_err2()); bool status = false; if(dist_mode == eucl_dist) { status = km_iterate<1>(X, km_iter, print_mode); } else if(dist_mode == maha_dist) { status = km_iterate<2>(X, km_iter, print_mode); } stream_state.restore(get_stream_err2()); if(status == false) { arma_debug_warn("gmm_diag::learn(): k-means algorithm failed; not enough data, or too many gaussians requested"); init(orig); return false; } } // initial dcovs const eT vfloor = (eT(var_floor) > eT(0)) ? eT(var_floor) : std::numeric_limits::min(); if(seed_mode != keep_existing) { if(print_mode) { get_stream_err2() << "gmm_diag::learn(): generating initial covariances\n"; } if(dist_mode == eucl_dist) { generate_initial_dcovs_and_hefts<1>(X, vfloor); } else if(dist_mode == maha_dist) { generate_initial_dcovs_and_hefts<2>(X, vfloor); } } // EM algorithm if(em_iter > 0) { const arma_ostream_state stream_state(get_stream_err2()); const bool status = em_iterate(X, em_iter, vfloor, print_mode); stream_state.restore(get_stream_err2()); if(status == false) { arma_debug_warn("gmm_diag::learn(): EM algorithm failed"); init(orig); return false; } } mah_aux.reset(); init_constants(); return true; } // // // template inline void gmm_diag::init(const gmm_diag& x) { arma_extra_debug_sigprint(); gmm_diag& t = *this; if(&t != &x) { access::rw(t.means) = x.means; access::rw(t.dcovs) = x.dcovs; access::rw(t.hefts) = x.hefts; init_constants(); } } template inline void gmm_diag::init(const uword in_n_dims, const uword in_n_gaus) { arma_extra_debug_sigprint(); access::rw(means).zeros(in_n_dims, in_n_gaus); access::rw(dcovs).ones(in_n_dims, in_n_gaus); access::rw(hefts).set_size(in_n_gaus); access::rw(hefts).fill(eT(1) / eT(in_n_gaus)); init_constants(); } template inline void gmm_diag::init_constants() { arma_extra_debug_sigprint(); const uword N_dims = means.n_rows; const uword N_gaus = means.n_cols; const eT tmp = (eT(N_dims)/eT(2)) * std::log(eT(2) * Datum::pi); log_det_etc.set_size(N_gaus); for(uword i=0; i::min() ); } log_hefts = log(hefts); } template inline umat gmm_diag::internal_gen_boundaries(const uword N) const { arma_extra_debug_sigprint(); #if defined(_OPENMP) // const uword n_cores = 0; const uword n_cores = uword(omp_get_num_procs()); const uword n_threads = (n_cores > 0) ? ( (n_cores <= N) ? n_cores : 1 ) : 1; #else // static const uword n_cores = 0; static const uword n_threads = 1; #endif // get_stream_err2() << "gmm_diag::internal_gen_boundaries(): n_cores: " << n_cores << '\n'; // get_stream_err2() << "gmm_diag::internal_gen_boundaries(): n_threads: " << n_threads << '\n'; umat boundaries(2, n_threads); if(N > 0) { const uword chunk_size = N / n_threads; uword count = 0; for(uword t=0; t arma_hot inline eT gmm_diag::internal_scalar_log_p(const eT* x) const { arma_extra_debug_sigprint(); const eT* log_hefts_mem = log_hefts.mem; const uword N_gaus = means.n_cols; if(N_gaus > 0) { eT log_sum = internal_scalar_log_p(x, 0) + log_hefts_mem[0]; for(uword g=1; g < N_gaus; ++g) { const eT tmp = internal_scalar_log_p(x, g) + log_hefts_mem[g]; log_sum = log_add_exp(log_sum, tmp); } return log_sum; } else { return -Datum::inf; } } template arma_hot inline eT gmm_diag::internal_scalar_log_p(const eT* x, const uword g) const { arma_extra_debug_sigprint(); const eT* mean = means.colptr(g); const eT* dcov = dcovs.colptr(g); const uword N_dims = means.n_rows; eT val_i = eT(0); eT val_j = eT(0); uword i,j; for(i=0, j=1; j template inline Row gmm_diag::internal_vec_log_p(const T1& X) const { arma_extra_debug_sigprint(); arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::log_p(): incompatible dimensions" ); const uword N = X.n_cols; Row out(N); if(N > 0) { #if defined(_OPENMP) { const arma_omp_state save_omp_state; const umat boundaries = internal_gen_boundaries(N); const uword n_threads = boundaries.n_cols; #pragma omp parallel for for(uword t=0; t < n_threads; ++t) { const uword start_index = boundaries.at(0,t); const uword end_index = boundaries.at(1,t); eT* out_mem = out.memptr(); for(uword i=start_index; i <= end_index; ++i) { out_mem[i] = internal_scalar_log_p( X.colptr(i) ); } } } #else { eT* out_mem = out.memptr(); for(uword i=0; i < N; ++i) { out_mem[i] = internal_scalar_log_p( X.colptr(i) ); } } #endif } return out; } template template inline Row gmm_diag::internal_vec_log_p(const T1& X, const uword gaus_id) const { arma_extra_debug_sigprint(); arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::log_p(): incompatible dimensions" ); arma_debug_check( (gaus_id >= means.n_cols), "gmm_diag::log_p(): gaus_id is out of range" ); const uword N = X.n_cols; Row out(N); if(N > 0) { #if defined(_OPENMP) { const arma_omp_state save_omp_state; const umat boundaries = internal_gen_boundaries(N); const uword n_threads = boundaries.n_cols; #pragma omp parallel for for(uword t=0; t < n_threads; ++t) { const uword start_index = boundaries.at(0,t); const uword end_index = boundaries.at(1,t); eT* out_mem = out.memptr(); for(uword i=start_index; i <= end_index; ++i) { out_mem[i] = internal_scalar_log_p( X.colptr(i), gaus_id ); } } } #else { eT* out_mem = out.memptr(); for(uword i=0; i < N; ++i) { out_mem[i] = internal_scalar_log_p( X.colptr(i), gaus_id ); } } #endif } return out; } template template inline eT gmm_diag::internal_avg_log_p(const T1& X) const { arma_extra_debug_sigprint(); arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::avg_log_p(): incompatible dimensions" ); const uword N = X.n_cols; if(N == 0) { return (-Datum::inf); } #if defined(_OPENMP) { const arma_omp_state save_omp_state; const umat boundaries = internal_gen_boundaries(N); const uword n_threads = boundaries.n_cols; field< running_mean_scalar > t_running_means(n_threads); #pragma omp parallel for for(uword t=0; t < n_threads; ++t) { const uword start_index = boundaries.at(0,t); const uword end_index = boundaries.at(1,t); running_mean_scalar& current_running_mean = t_running_means[t]; for(uword i=start_index; i <= end_index; ++i) { current_running_mean( internal_scalar_log_p( X.colptr(i) ) ); } } eT avg = eT(0); for(uword t=0; t < n_threads; ++t) { running_mean_scalar& current_running_mean = t_running_means[t]; const eT w = eT(current_running_mean.count()) / eT(N); avg += w * current_running_mean.mean(); } return avg; } #else { running_mean_scalar running_mean; for(uword i=0; i template inline eT gmm_diag::internal_avg_log_p(const T1& X, const uword gaus_id) const { arma_extra_debug_sigprint(); arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::avg_log_p(): incompatible dimensions" ); arma_debug_check( (gaus_id >= means.n_cols), "gmm_diag::avg_log_p(): specified gaussian is out of range" ); const uword N = X.n_cols; if(N == 0) { return (-Datum::inf); } #if defined(_OPENMP) { const arma_omp_state save_omp_state; const umat boundaries = internal_gen_boundaries(N); const uword n_threads = boundaries.n_cols; field< running_mean_scalar > t_running_means(n_threads); #pragma omp parallel for for(uword t=0; t < n_threads; ++t) { const uword start_index = boundaries.at(0,t); const uword end_index = boundaries.at(1,t); running_mean_scalar& current_running_mean = t_running_means[t]; for(uword i=start_index; i <= end_index; ++i) { current_running_mean( internal_scalar_log_p( X.colptr(i), gaus_id) ); } } eT avg = eT(0); for(uword t=0; t < n_threads; ++t) { running_mean_scalar& current_running_mean = t_running_means[t]; const eT w = eT(current_running_mean.count()) / eT(N); avg += w * current_running_mean.mean(); } return avg; } #else { running_mean_scalar running_mean; for(uword i=0; i template inline uword gmm_diag::internal_scalar_assign(const T1& X, const gmm_dist_mode& dist_mode) const { arma_extra_debug_sigprint(); const uword N_dims = means.n_rows; const uword N_gaus = means.n_cols; arma_debug_check( (X.n_rows != N_dims), "gmm_diag::assign(): incompatible dimensions" ); arma_debug_check( (N_gaus == 0), "gmm_diag::assign(): model has no means" ); const eT* X_mem = X.colptr(0); if(dist_mode == eucl_dist) { eT best_dist = Datum::inf; uword best_g = 0; for(uword g=0; g < N_gaus; ++g) { const eT tmp_dist = distance::eval(N_dims, X_mem, means.colptr(g), X_mem); if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; } } return best_g; } else if(dist_mode == prob_dist) { const eT* log_hefts_mem = log_hefts.memptr(); eT best_p = -Datum::inf; uword best_g = 0; for(uword g=0; g < N_gaus; ++g) { const eT tmp_p = internal_scalar_log_p(X_mem, g) + log_hefts_mem[g]; if(tmp_p >= best_p) { best_p = tmp_p; best_g = g; } } return best_g; } else { arma_debug_check(true, "gmm_diag::assign(): unsupported distance mode"); } return uword(0); } template template inline void gmm_diag::internal_vec_assign(urowvec& out, const T1& X, const gmm_dist_mode& dist_mode) const { arma_extra_debug_sigprint(); const uword N_dims = means.n_rows; const uword N_gaus = means.n_cols; arma_debug_check( (X.n_rows != N_dims), "gmm_diag::assign(): incompatible dimensions" ); const uword X_n_cols = (N_gaus > 0) ? X.n_cols : 0; out.set_size(1,X_n_cols); uword* out_mem = out.memptr(); if(dist_mode == eucl_dist) { for(uword i=0; i::inf; uword best_g = 0; for(uword g=0; g::eval(N_dims, X_colptr, means.colptr(g), X_colptr); if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; } } out_mem[i] = best_g; } } else if(dist_mode == prob_dist) { const eT* log_hefts_mem = log_hefts.memptr(); for(uword i=0; i::inf; uword best_g = 0; for(uword g=0; g= best_p) { best_p = tmp_p; best_g = g; } } out_mem[i] = best_g; } } else { arma_debug_check(true, "gmm_diag::assign(): unsupported distance mode"); } } template inline void gmm_diag::internal_raw_hist(urowvec& hist, const Mat& X, const gmm_dist_mode& dist_mode) const { arma_extra_debug_sigprint(); const uword N_dims = means.n_rows; const uword N_gaus = means.n_cols; const uword X_n_cols = X.n_cols; hist.zeros(N_gaus); if(N_gaus == 0) { return; } uword* hist_mem = hist.memptr(); if(dist_mode == eucl_dist) { for(uword i=0; i::inf; uword best_g = 0; for(uword g=0; g < N_gaus; ++g) { const eT tmp_dist = distance::eval(N_dims, X_colptr, means.colptr(g), X_colptr); if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; } } hist_mem[best_g]++; } } else if(dist_mode == prob_dist) { const eT* log_hefts_mem = log_hefts.memptr(); for(uword i=0; i::inf; uword best_g = 0; for(uword g=0; g < N_gaus; ++g) { const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g]; if(tmp_p >= best_p) { best_p = tmp_p; best_g = g; } } hist_mem[best_g]++; } } } template template inline void gmm_diag::generate_initial_means(const Mat& X, const gmm_seed_mode& seed_mode) { const uword N_dims = means.n_rows; const uword N_gaus = means.n_cols; if( (seed_mode == static_subset) || (seed_mode == random_subset) ) { uvec initial_indices; if(seed_mode == static_subset) { initial_indices = linspace(0, X.n_cols-1, N_gaus); } else if(seed_mode == random_subset) { initial_indices = uvec(sort_index(randu(X.n_cols))).rows(0,N_gaus-1); } // not using randi() here as on some primitive systems it produces vectors with non-unique values // initial_indices.print("initial_indices:"); access::rw(means) = X.cols(initial_indices); } else if( (seed_mode == static_spread) || (seed_mode == random_spread) ) { uword start_index = 0; if(seed_mode == static_spread) { start_index = X.n_cols / 2; } else if(seed_mode == random_spread) { start_index = as_scalar(randi(1, distr_param(0,X.n_cols-1))); } access::rw(means).col(0) = X.unsafe_col(start_index); const eT* mah_aux_mem = mah_aux.memptr(); running_stat rs; for(uword g=1; g < N_gaus; ++g) { eT max_dist = eT(0); uword best_i = uword(0); for(uword i=0; i < X.n_cols; ++i) { rs.reset(); const eT* X_colptr = X.colptr(i); bool ignore_i = false; // find the average distance between sample i and the means so far for(uword h = 0; h < g; ++h) { const eT dist = distance::eval(N_dims, X_colptr, means.colptr(h), mah_aux_mem); // ignore sample already selected as a mean if(dist == eT(0)) { ignore_i = true; break; } else { rs(dist); } } if( (rs.mean() >= max_dist) && (ignore_i == false)) { max_dist = eT(rs.mean()); best_i = i; } } // set the mean to the sample that is the furthest away from the means so far access::rw(means).col(g) = X.unsafe_col(best_i); } } // get_stream_err2() << "generate_initial_means():" << '\n'; // means.print(); } template template inline void gmm_diag::generate_initial_dcovs_and_hefts(const Mat& X, const eT var_floor) { const uword N_dims = means.n_rows; const uword N_gaus = means.n_cols; field< running_stat_vec< Col > > rs(N_gaus); const eT* mah_aux_mem = mah_aux.memptr(); for(uword i=0; i::inf; uword best_g = 0; for(uword g=0; g::eval(N_dims, X_colptr, means.colptr(g), mah_aux_mem); if(dist <= min_dist) { min_dist = dist; best_g = g; } } rs(best_g)(X.unsafe_col(i)); } for(uword g=0; g= eT(2) ) { access::rw(dcovs).col(g) = rs(g).var(1); } else { access::rw(dcovs).col(g).ones(); } access::rw(hefts)(g) = (std::max)( (rs(g).count() / eT(X.n_cols)), std::numeric_limits::min() ); } em_fix_params(var_floor); } //! multi-threaded implementation of k-means, inspired by MapReduce template template inline bool gmm_diag::km_iterate(const Mat& X, const uword max_iter, const bool verbose) { arma_extra_debug_sigprint(); if(verbose) { get_stream_err2().unsetf(ios::showbase); get_stream_err2().unsetf(ios::uppercase); get_stream_err2().unsetf(ios::showpos); get_stream_err2().unsetf(ios::scientific); get_stream_err2().setf(ios::right); get_stream_err2().setf(ios::fixed); } const uword N_dims = means.n_rows; const uword N_gaus = means.n_cols; Mat old_means = means; Mat new_means = means; running_mean_scalar rs_delta; field< running_mean_vec > running_means(N_gaus); const eT* mah_aux_mem = mah_aux.memptr(); #if defined(_OPENMP) const arma_omp_state save_omp_state; const umat boundaries = internal_gen_boundaries(X.n_cols); const uword n_threads = boundaries.n_cols; field< field< running_mean_vec > > t_running_means(n_threads); for(uword t=0; t < n_threads; ++t) { t_running_means[t].set_size(N_gaus); } vec tmp_mean(N_dims); if(verbose) { get_stream_err2() << "gmm_diag::learn(): k-means: n_threads: " << n_threads << '\n'; } #endif for(uword iter=1; iter <= max_iter; ++iter) { #if defined(_OPENMP) { for(uword t=0; t < n_threads; ++t) { for(uword g=0; g < N_gaus; ++g) { t_running_means[t][g].reset(); } } // km_update_stats() is the "map" operation, which produces partial means #pragma omp parallel for for(uword t=0; t < n_threads; ++t) { field< running_mean_vec >& current_running_means = t_running_means[t]; km_update_stats(X, boundaries.at(0,t), boundaries.at(1,t), old_means, current_running_means); } // the "reduce" operation, which combines the partial means produced by the separate threads; // takes into account the counts for each mean for(uword g=0; g < N_gaus; ++g) { uword total_count = 0; for(uword t=0; t < n_threads; ++t) { total_count += t_running_means[t][g].count(); } tmp_mean.zeros(); bool dead = true; uword last_index = 0; if(total_count > 0) { for(uword t=0; t < n_threads; ++t) { const eT w = eT(t_running_means[t][g].count()) / eT(total_count); if(w > eT(0)) { tmp_mean += w * t_running_means[t][g].mean(); dead = false; last_index = t_running_means[t][g].last_index(); } } } running_means[g].reset(); if(dead == false) { running_means[g](tmp_mean, last_index); } } } #else { for(uword g=0; g < N_gaus; ++g) { running_means[g].reset(); } km_update_stats(X, 0, X.n_cols-1, old_means, running_means); } #endif uword n_dead_means = 0; for(uword g=0; g < N_gaus; ++g) { if(running_means[g].count() > 0) { new_means.col(g) = running_means[g].mean(); } else { n_dead_means++; } } // heuristics to resurrect dead means if(n_dead_means > 0) { if(verbose) { get_stream_err2() << "gmm_diag::learn(): k-means: recovering from dead means\n"; } if(n_dead_means == 1) { uword dead_g = 0; uword populous_g = 0; uword populous_count = running_means(0).count(); for(uword g=1; g < N_gaus; ++g) { const uword count = running_means(g).count(); if(count == 0) { dead_g = g; } if(populous_count < count) { populous_count = count; populous_g = g; } } if( (populous_count <= 2) || (dead_g == populous_g) ) { return false; } new_means.col(dead_g) = X.unsafe_col( running_means(populous_g).last_index() ); } else { uword n_resurrected_means = 0; uword dead_g = 0; for(uword live_g = 0; live_g < N_gaus; ++live_g) { if(running_means(live_g).count() >= 2) { for(; dead_g < N_gaus; ++dead_g) { if(running_means(dead_g).count() == 0) { break; } } if(dead_g == N_gaus) { break; } new_means.col(dead_g) = X.unsafe_col( running_means(live_g).last_index() ); dead_g++; n_resurrected_means++; } } if(n_resurrected_means != n_dead_means) { if(verbose) { get_stream_err2() << "gmm_diag::learn(): k-means: WARNING: did not resurrect all dead means\n"; } } } } rs_delta.reset(); for(uword g=0; g < N_gaus; ++g) { rs_delta( distance::eval(N_dims, old_means.colptr(g), new_means.colptr(g), mah_aux_mem) ); } if(verbose) { get_stream_err2() << "gmm_diag::learn(): k-means: iteration: "; get_stream_err2().unsetf(ios::scientific); get_stream_err2().setf(ios::fixed); get_stream_err2().width(std::streamsize(4)); get_stream_err2() << iter; get_stream_err2() << " delta: "; get_stream_err2().unsetf(ios::fixed); //get_stream_err2().setf(ios::scientific); get_stream_err2() << rs_delta.mean() << '\n'; } arma::swap(old_means, new_means); if(rs_delta.mean() <= Datum::eps) { break; } } access::rw(means) = old_means; return true; } template template inline void gmm_diag::km_update_stats(const Mat& X, const uword start_index, const uword end_index, const Mat& old_means, field< running_mean_vec >& running_means) const { arma_extra_debug_sigprint(); // get_stream_err2() << "km_update_stats(): start_index: " << start_index << '\n'; // get_stream_err2() << "km_update_stats(): end_index: " << end_index << '\n'; const uword N_dims = means.n_rows; const uword N_gaus = means.n_cols; const eT* mah_aux_mem = mah_aux.memptr(); for(uword i=start_index; i <= end_index; ++i) { const eT* X_colptr = X.colptr(i); double best_dist = Datum::inf; uword best_g = 0; for(uword g=0; g < N_gaus; ++g) { const double dist = distance::eval(N_dims, X_colptr, old_means.colptr(g), mah_aux_mem); // get_stream_err2() << "g: " << g << " dist: " << dist << '\n'; // old_means.col(g).print("old_means.col(g):"); // vec tmp(old_means.colptr(g), old_means.n_rows); // tmp.print("tmp:"); if(dist <= best_dist) { best_dist = dist; best_g = g; } } // get_stream_err2() << "best_g: " << best_g << '\n'; running_means[best_g]( X.unsafe_col(i), i ); } } //! multi-threaded implementation of Expectation-Maximisation, inspired by MapReduce template inline bool gmm_diag::em_iterate(const Mat& X, const uword max_iter, const eT var_floor, const bool verbose) { arma_extra_debug_sigprint(); const uword N_dims = means.n_rows; const uword N_gaus = means.n_cols; if(verbose) { get_stream_err2().unsetf(ios::showbase); get_stream_err2().unsetf(ios::uppercase); get_stream_err2().unsetf(ios::showpos); get_stream_err2().unsetf(ios::scientific); get_stream_err2().setf(ios::right); get_stream_err2().setf(ios::fixed); } #if defined(_OPENMP) const arma_omp_state save_omp_state; #endif const umat boundaries = internal_gen_boundaries(X.n_cols); const uword n_threads = boundaries.n_cols; field< Mat > t_acc_means(n_threads); field< Mat > t_acc_dcovs(n_threads); field< Col > t_acc_norm_lhoods(n_threads); field< Col > t_gaus_log_lhoods(n_threads); Col t_progress_log_lhood(n_threads); for(uword t=0; t::inf; for(uword iter=1; iter <= max_iter; ++iter) { init_constants(); em_update_params(X, boundaries, t_acc_means, t_acc_dcovs, t_acc_norm_lhoods, t_gaus_log_lhoods, t_progress_log_lhood); em_fix_params(var_floor); const eT new_avg_log_p = mean(t_progress_log_lhood); if(verbose) { get_stream_err2() << "gmm_diag::learn(): EM: iteration: "; get_stream_err2().unsetf(ios::scientific); get_stream_err2().setf(ios::fixed); get_stream_err2().width(std::streamsize(4)); get_stream_err2() << iter; get_stream_err2() << " avg_log_p: "; get_stream_err2().unsetf(ios::fixed); //get_stream_err2().setf(ios::scientific); get_stream_err2() << new_avg_log_p << '\n'; } if(is_finite(new_avg_log_p) == false) { return false; } if(std::abs(old_avg_log_p - new_avg_log_p) <= Datum::eps) { break; } old_avg_log_p = new_avg_log_p; } if(any(vectorise(dcovs) <= eT(0))) { return false; } if(means.is_finite() == false ) { return false; } if(dcovs.is_finite() == false ) { return false; } if(hefts.is_finite() == false ) { return false; } return true; } template inline void gmm_diag::em_update_params ( const Mat& X, const umat& boundaries, field< Mat >& t_acc_means, field< Mat >& t_acc_dcovs, field< Col >& t_acc_norm_lhoods, field< Col >& t_gaus_log_lhoods, Col& t_progress_log_lhood ) { arma_extra_debug_sigprint(); const uword n_threads = boundaries.n_cols; // em_generate_acc() is the "map" operation, which produces partial accumulators for means, diagonal covariances and hefts #if defined(_OPENMP) { #pragma omp parallel for for(uword t=0; t& acc_means = t_acc_means[t]; Mat& acc_dcovs = t_acc_dcovs[t]; Col& acc_norm_lhoods = t_acc_norm_lhoods[t]; Col& gaus_log_lhoods = t_gaus_log_lhoods[t]; eT& progress_log_lhood = t_progress_log_lhood[t]; em_generate_acc(X, boundaries.at(0,t), boundaries.at(1,t), acc_means, acc_dcovs, acc_norm_lhoods, gaus_log_lhoods, progress_log_lhood); } } #else { em_generate_acc(X, boundaries.at(0,0), boundaries.at(1,0), t_acc_means[0], t_acc_dcovs[0], t_acc_norm_lhoods[0], t_gaus_log_lhoods[0], t_progress_log_lhood[0]); } #endif const uword N_dims = means.n_rows; const uword N_gaus = means.n_cols; Mat& final_acc_means = t_acc_means[0]; Mat& final_acc_dcovs = t_acc_dcovs[0]; Col& final_acc_norm_lhoods = t_acc_norm_lhoods[0]; // the "reduce" operation, which combines the partial accumulators produced by the separate threads for(uword t=1; t::min() ); hefts_mem[g] = acc_norm_lhood / eT(X.n_cols); for(uword d=0; d < N_dims; ++d) { const eT tmp = acc_mean_mem[d] / acc_norm_lhood; mean_mem[d] = tmp; dcov_mem[d] = acc_dcov_mem[d] / acc_norm_lhood - tmp*tmp; } } } template inline void gmm_diag::em_generate_acc ( const Mat& X, const uword start_index, const uword end_index, Mat& acc_means, Mat& acc_dcovs, Col& acc_norm_lhoods, Col& gaus_log_lhoods, eT& progress_log_lhood ) const { arma_extra_debug_sigprint(); progress_log_lhood = eT(0); acc_means.zeros(); acc_dcovs.zeros(); acc_norm_lhoods.zeros(); gaus_log_lhoods.zeros(); const uword N_dims = means.n_rows; const uword N_gaus = means.n_cols; const eT* log_hefts_mem = log_hefts.memptr(); eT* gaus_log_lhoods_mem = gaus_log_lhoods.memptr(); for(uword i=start_index; i <= end_index; i++) { const eT* x = X.colptr(i); for(uword g=0; g < N_gaus; ++g) { gaus_log_lhoods_mem[g] = internal_scalar_log_p(x, g) + log_hefts_mem[g]; } eT log_lhood_sum = gaus_log_lhoods_mem[0]; for(uword g=1; g < N_gaus; ++g) { log_lhood_sum = log_add_exp(log_lhood_sum, gaus_log_lhoods_mem[g]); } progress_log_lhood += log_lhood_sum; for(uword g=0; g < N_gaus; ++g) { const eT norm_lhood = std::exp(gaus_log_lhoods_mem[g] - log_lhood_sum); acc_norm_lhoods[g] += norm_lhood; eT* acc_mean_mem = acc_means.colptr(g); eT* acc_dcov_mem = acc_dcovs.colptr(g); for(uword d=0; d < N_dims; ++d) { const eT x_d = x[d]; const eT y_d = x_d * norm_lhood; acc_mean_mem[d] += y_d; acc_dcov_mem[d] += y_d * x_d; // equivalent to x_d * x_d * norm_lhood } } } progress_log_lhood /= eT((end_index - start_index) + 1); } template inline void gmm_diag::em_fix_params(const eT var_floor) { arma_extra_debug_sigprint(); const uword N_dims = means.n_rows; const uword N_gaus = means.n_cols; for(uword g=0; g < N_gaus; ++g) { eT* dcov_mem = access::rw(dcovs).colptr(g); for(uword d=0; d < N_dims; ++d) { if(dcov_mem[d] < var_floor) { dcov_mem[d] = var_floor; } } } const eT heft_sum = accu(hefts); if(heft_sum != eT(1)) { access::rw(hefts) / heft_sum; } } } //! @}