60 const int p =
static_cast<int>(matData.rows());
61 const int n =
static_cast<int>(matData.cols());
64 const MatrixXd S = (matData * matData.transpose()) /
static_cast<double>(n);
67 const double mu = S.trace() /
static_cast<double>(p);
70 MatrixXd centered = S - mu * MatrixXd::Identity(p, p);
71 const double delta = centered.squaredNorm() /
static_cast<double>(p);
80 double sumSqNorms = 0.0;
81 for (
int k = 0; k < n; ++k) {
82 const double nrm = matData.col(k).squaredNorm();
83 sumSqNorms += nrm * nrm;
85 const double betaSum = sumSqNorms -
static_cast<double>(n) * S.squaredNorm();
86 const double beta = betaSum / (
static_cast<double>(n) *
static_cast<double>(n) *
static_cast<double>(p));
89 const double alpha = std::clamp(beta / delta, 0.0, 1.0);
92 MatrixXd covShrunk = (1.0 - alpha) * S;
93 covShrunk.diagonal().array() += alpha * mu;
95 return {covShrunk, alpha};