98 if (epochs.isEmpty()) {
99 qWarning() <<
"Xdawn::fit: empty epoch list.";
103 QVector<int> goodIdx;
104 goodIdx.reserve(epochs.size());
105 for (
int i = 0; i < epochs.size(); ++i) {
106 if (!epochs[i].bReject) {
111 if (goodIdx.isEmpty()) {
112 qWarning() <<
"Xdawn::fit: no non-rejected epochs available.";
116 const int nCh =
static_cast<int>(epochs[goodIdx[0]].epoch.rows());
117 const int nSamp =
static_cast<int>(epochs[goodIdx[0]].epoch.cols());
118 if (nCh == 0 || nSamp == 0) {
119 qWarning() <<
"Xdawn::fit: epoch matrices are empty.";
123 for (
int idx : goodIdx) {
124 if (epochs[idx].epoch.rows() != nCh || epochs[idx].epoch.cols() != nSamp) {
125 qWarning() <<
"Xdawn::fit: epoch dimension mismatch.";
130 nComponents = std::max(1, std::min(nComponents, nCh));
132 QHash<int, MatrixXd> classSums;
133 QHash<int, int> classCounts;
134 MatrixXd targetSum = MatrixXd::Zero(nCh, nSamp);
137 for (
int idx : goodIdx) {
139 if (!classSums.contains(ep.
event)) {
140 classSums.insert(ep.
event, MatrixXd::Zero(nCh, nSamp));
141 classCounts.insert(ep.
event, 0);
145 classCounts[ep.
event] += 1;
147 if (ep.
event == iTargetEvent) {
148 targetSum += ep.
epoch;
154 qWarning() <<
"Xdawn::fit: no target epochs found for event" << iTargetEvent;
160 MatrixXd noiseCov = MatrixXd::Zero(nCh, nCh);
161 MatrixXd dataCov = MatrixXd::Zero(nCh, nCh);
162 long long nNoiseSamples = 0;
163 long long nDataSamples = 0;
165 QHash<int, MatrixXd> classMeans;
166 for (
auto it = classSums.constBegin(); it != classSums.constEnd(); ++it) {
167 classMeans.insert(it.key(), it.value() /
static_cast<double>(classCounts.value(it.key())));
170 for (
int idx : goodIdx) {
172 const MatrixXd residual = ep.
epoch - classMeans.value(ep.
event);
175 noiseCov += residual * residual.transpose();
176 nDataSamples += nSamp;
177 nNoiseSamples += nSamp;
180 if (nNoiseSamples <= 0 || nDataSamples <= 0) {
181 qWarning() <<
"Xdawn::fit: failed to accumulate covariance samples.";
186 result.
matNoiseCov = noiseCov /
static_cast<double>(nNoiseSamples);
187 dataCov = dataCov /
static_cast<double>(nDataSamples);
189 const double traceNoise = result.
matNoiseCov.trace();
190 const double regValue = std::max(dReg, 0.0) * ((traceNoise > 0.0) ? traceNoise /
static_cast<double>(nCh) : 1.0);
192 regNoiseCov.diagonal().array() += regValue;
194 SelfAdjointEigenSolver<MatrixXd> noiseEig(regNoiseCov);
195 if (noiseEig.info() != Success) {
196 qWarning() <<
"Xdawn::fit: noise covariance eigendecomposition failed.";
200 VectorXd noiseVals = noiseEig.eigenvalues().cwiseMax(1e-12);
201 MatrixXd noiseVecs = noiseEig.eigenvectors();
202 MatrixXd invSqrtNoise = noiseVecs * noiseVals.cwiseInverse().cwiseSqrt().asDiagonal() * noiseVecs.transpose();
204 MatrixXd whitenedSignal = invSqrtNoise * result.
matSignalCov * invSqrtNoise;
205 SelfAdjointEigenSolver<MatrixXd> signalEig(whitenedSignal);
206 if (signalEig.info() != Success) {
207 qWarning() <<
"Xdawn::fit: signal covariance eigendecomposition failed.";
212 const MatrixXd signalVecsAsc = signalEig.eigenvectors().rightCols(nComponents);
213 MatrixXd signalVecs(signalVecsAsc.rows(), signalVecsAsc.cols());
214 for (
int i = 0; i < nComponents; ++i) {
215 signalVecs.col(i) = signalVecsAsc.col(nComponents - 1 - i);
218 result.
matFilters = invSqrtNoise * signalVecs;
220 for (
int col = 0; col < result.
matFilters.cols(); ++col) {
221 const double noiseNorm = std::sqrt(result.
matFilters.col(col).transpose()
224 if (noiseNorm > 1e-12) {