v2.0.0
Loading...
Searching...
No Matches
xdawn.cpp
Go to the documentation of this file.
1//=============================================================================================================
33
34//=============================================================================================================
35// INCLUDES
36//=============================================================================================================
37
38#include "xdawn.h"
39
40//=============================================================================================================
41// EIGEN INCLUDES
42//=============================================================================================================
43
44#include <Eigen/Dense>
45
46//=============================================================================================================
47// QT INCLUDES
48//=============================================================================================================
49
50#include <QDebug>
51#include <QHash>
52
53//=============================================================================================================
54// C++ INCLUDES
55//=============================================================================================================
56
57#include <algorithm>
58
59//=============================================================================================================
60// USED NAMESPACES
61//=============================================================================================================
62
63using namespace UTILSLIB;
64using namespace MNELIB;
65using namespace Eigen;
66
67//=============================================================================================================
68// PRIVATE HELPERS
69//=============================================================================================================
70
71namespace {
72
73MatrixXd computePatterns(const MatrixXd& filters, const MatrixXd& dataCov)
74{
75 if (filters.size() == 0 || dataCov.size() == 0) {
76 return {};
77 }
78
79 MatrixXd gram = filters.transpose() * dataCov * filters;
80 CompleteOrthogonalDecomposition<MatrixXd> cod(gram);
81 return dataCov * filters * cod.pseudoInverse();
82}
83
84} // anonymous namespace
85
86//=============================================================================================================
87// MEMBER DEFINITIONS
88//=============================================================================================================
89
90XdawnResult Xdawn::fit(const QVector<MNEEpochData>& epochs,
91 int iTargetEvent,
92 int nComponents,
93 double dReg)
94{
95 XdawnResult result;
96 result.iTargetEvent = iTargetEvent;
97
98 if (epochs.isEmpty()) {
99 qWarning() << "Xdawn::fit: empty epoch list.";
100 return result;
101 }
102
103 QVector<int> goodIdx;
104 goodIdx.reserve(epochs.size());
105 for (int i = 0; i < epochs.size(); ++i) {
106 if (!epochs[i].bReject) {
107 goodIdx.append(i);
108 }
109 }
110
111 if (goodIdx.isEmpty()) {
112 qWarning() << "Xdawn::fit: no non-rejected epochs available.";
113 return result;
114 }
115
116 const int nCh = static_cast<int>(epochs[goodIdx[0]].epoch.rows());
117 const int nSamp = static_cast<int>(epochs[goodIdx[0]].epoch.cols());
118 if (nCh == 0 || nSamp == 0) {
119 qWarning() << "Xdawn::fit: epoch matrices are empty.";
120 return result;
121 }
122
123 for (int idx : goodIdx) {
124 if (epochs[idx].epoch.rows() != nCh || epochs[idx].epoch.cols() != nSamp) {
125 qWarning() << "Xdawn::fit: epoch dimension mismatch.";
126 return result;
127 }
128 }
129
130 nComponents = std::max(1, std::min(nComponents, nCh));
131
132 QHash<int, MatrixXd> classSums;
133 QHash<int, int> classCounts;
134 MatrixXd targetSum = MatrixXd::Zero(nCh, nSamp);
135 int nTarget = 0;
136
137 for (int idx : goodIdx) {
138 const MNEEpochData& ep = epochs[idx];
139 if (!classSums.contains(ep.event)) {
140 classSums.insert(ep.event, MatrixXd::Zero(nCh, nSamp));
141 classCounts.insert(ep.event, 0);
142 }
143
144 classSums[ep.event] += ep.epoch;
145 classCounts[ep.event] += 1;
146
147 if (ep.event == iTargetEvent) {
148 targetSum += ep.epoch;
149 ++nTarget;
150 }
151 }
152
153 if (nTarget == 0) {
154 qWarning() << "Xdawn::fit: no target epochs found for event" << iTargetEvent;
155 return result;
156 }
157
158 result.matTargetEvoked = targetSum / static_cast<double>(nTarget);
159
160 MatrixXd noiseCov = MatrixXd::Zero(nCh, nCh);
161 MatrixXd dataCov = MatrixXd::Zero(nCh, nCh);
162 long long nNoiseSamples = 0;
163 long long nDataSamples = 0;
164
165 QHash<int, MatrixXd> classMeans;
166 for (auto it = classSums.constBegin(); it != classSums.constEnd(); ++it) {
167 classMeans.insert(it.key(), it.value() / static_cast<double>(classCounts.value(it.key())));
168 }
169
170 for (int idx : goodIdx) {
171 const MNEEpochData& ep = epochs[idx];
172 const MatrixXd residual = ep.epoch - classMeans.value(ep.event);
173
174 dataCov += ep.epoch * ep.epoch.transpose();
175 noiseCov += residual * residual.transpose();
176 nDataSamples += nSamp;
177 nNoiseSamples += nSamp;
178 }
179
180 if (nNoiseSamples <= 0 || nDataSamples <= 0) {
181 qWarning() << "Xdawn::fit: failed to accumulate covariance samples.";
182 return result;
183 }
184
185 result.matSignalCov = result.matTargetEvoked * result.matTargetEvoked.transpose() / static_cast<double>(nSamp);
186 result.matNoiseCov = noiseCov / static_cast<double>(nNoiseSamples);
187 dataCov = dataCov / static_cast<double>(nDataSamples);
188
189 const double traceNoise = result.matNoiseCov.trace();
190 const double regValue = std::max(dReg, 0.0) * ((traceNoise > 0.0) ? traceNoise / static_cast<double>(nCh) : 1.0);
191 MatrixXd regNoiseCov = result.matNoiseCov;
192 regNoiseCov.diagonal().array() += regValue;
193
194 SelfAdjointEigenSolver<MatrixXd> noiseEig(regNoiseCov);
195 if (noiseEig.info() != Success) {
196 qWarning() << "Xdawn::fit: noise covariance eigendecomposition failed.";
197 return result;
198 }
199
200 VectorXd noiseVals = noiseEig.eigenvalues().cwiseMax(1e-12);
201 MatrixXd noiseVecs = noiseEig.eigenvectors();
202 MatrixXd invSqrtNoise = noiseVecs * noiseVals.cwiseInverse().cwiseSqrt().asDiagonal() * noiseVecs.transpose();
203
204 MatrixXd whitenedSignal = invSqrtNoise * result.matSignalCov * invSqrtNoise;
205 SelfAdjointEigenSolver<MatrixXd> signalEig(whitenedSignal);
206 if (signalEig.info() != Success) {
207 qWarning() << "Xdawn::fit: signal covariance eigendecomposition failed.";
208 return result;
209 }
210
211 result.matFilters.resize(nCh, nComponents);
212 const MatrixXd signalVecsAsc = signalEig.eigenvectors().rightCols(nComponents);
213 MatrixXd signalVecs(signalVecsAsc.rows(), signalVecsAsc.cols());
214 for (int i = 0; i < nComponents; ++i) {
215 signalVecs.col(i) = signalVecsAsc.col(nComponents - 1 - i);
216 }
217
218 result.matFilters = invSqrtNoise * signalVecs;
219
220 for (int col = 0; col < result.matFilters.cols(); ++col) {
221 const double noiseNorm = std::sqrt(result.matFilters.col(col).transpose()
222 * regNoiseCov
223 * result.matFilters.col(col));
224 if (noiseNorm > 1e-12) {
225 result.matFilters.col(col) /= noiseNorm;
226 }
227 }
228
229 result.matPatterns = computePatterns(result.matFilters, dataCov);
230 result.bValid = true;
231 return result;
232}
233
234//=============================================================================================================
235
236MatrixXd Xdawn::apply(const MatrixXd& matEpoch, const XdawnResult& result)
237{
238 if (!result.bValid || result.matFilters.size() == 0) {
239 return {};
240 }
241
242 if (matEpoch.rows() != result.matFilters.rows()) {
243 qWarning() << "Xdawn::apply: channel count mismatch.";
244 return {};
245 }
246
247 return result.matFilters.transpose() * matEpoch;
248}
249
250//=============================================================================================================
251
252MatrixXd Xdawn::denoise(const MatrixXd& matEpoch, const XdawnResult& result, int nComponents)
253{
254 if (!result.bValid || result.matFilters.size() == 0 || result.matPatterns.size() == 0) {
255 return matEpoch;
256 }
257
258 if (matEpoch.rows() != result.matFilters.rows()) {
259 qWarning() << "Xdawn::denoise: channel count mismatch.";
260 return {};
261 }
262
263 if (nComponents <= 0 || nComponents > result.matFilters.cols()) {
264 nComponents = result.matFilters.cols();
265 }
266
267 MatrixXd filters = result.matFilters.leftCols(nComponents);
268 MatrixXd patterns = result.matPatterns.leftCols(nComponents);
269
270 return patterns * (filters.transpose() * matEpoch);
271}
272
273//=============================================================================================================
274
275QVector<MNEEpochData> Xdawn::denoiseEpochs(const QVector<MNEEpochData>& epochs,
276 const XdawnResult& result,
277 int nComponents)
278{
279 QVector<MNEEpochData> out = epochs;
280 for (int i = 0; i < out.size(); ++i) {
281 out[i].epoch = denoise(out[i].epoch, result, nComponents);
282 }
283 return out;
284}
Declaration of the Xdawn class for event-related response enhancement.
Core MNE data structures (source spaces, source estimates, hemispheres).
Shared utilities (I/O helpers, spectral analysis, layout management, warp algorithms).
Result of an xDAWN decomposition.
Definition xdawn.h:74
Eigen::MatrixXd matSignalCov
Definition xdawn.h:77
Eigen::MatrixXd matPatterns
Definition xdawn.h:76
Eigen::MatrixXd matNoiseCov
Definition xdawn.h:78
Eigen::MatrixXd matFilters
Definition xdawn.h:75
Eigen::MatrixXd matTargetEvoked
Definition xdawn.h:79
static Eigen::MatrixXd apply(const Eigen::MatrixXd &matEpoch, const XdawnResult &result)
Definition xdawn.cpp:236
static QVector< MNELIB::MNEEpochData > denoiseEpochs(const QVector< MNELIB::MNEEpochData > &epochs, const XdawnResult &result, int nComponents=-1)
Definition xdawn.cpp:275
static XdawnResult fit(const QVector< MNELIB::MNEEpochData > &epochs, int iTargetEvent=1, int nComponents=2, double dReg=1e-6)
Definition xdawn.cpp:90
static Eigen::MatrixXd denoise(const Eigen::MatrixXd &matEpoch, const XdawnResult &result, int nComponents=-1)
Definition xdawn.cpp:252
Eigen::MatrixXd epoch
FIFFLIB::fiff_int_t event