v2.0.0
Loading...
Searching...
No Matches
xdawn.cpp
Go to the documentation of this file.
1//=============================================================================================================
12
13//=============================================================================================================
14// INCLUDES
15//=============================================================================================================
16
17#include "xdawn.h"
18
19//=============================================================================================================
20// EIGEN INCLUDES
21//=============================================================================================================
22
23#include <Eigen/Dense>
24
25//=============================================================================================================
26// QT INCLUDES
27//=============================================================================================================
28
29#include <QDebug>
30#include <QHash>
31
32//=============================================================================================================
33// C++ INCLUDES
34//=============================================================================================================
35
#include <algorithm>
#include <cmath>
37
38//=============================================================================================================
39// USED NAMESPACES
40//=============================================================================================================
41
42using namespace UTILSLIB;
43using namespace MNELIB;
44using namespace Eigen;
45
46//=============================================================================================================
47// PRIVATE HELPERS
48//=============================================================================================================
49
50namespace {
51
52MatrixXd computePatterns(const MatrixXd& filters, const MatrixXd& dataCov)
53{
54 if (filters.size() == 0 || dataCov.size() == 0) {
55 return {};
56 }
57
58 MatrixXd gram = filters.transpose() * dataCov * filters;
59 CompleteOrthogonalDecomposition<MatrixXd> cod(gram);
60 return dataCov * filters * cod.pseudoInverse();
61}
62
63} // anonymous namespace
64
65//=============================================================================================================
66// MEMBER DEFINITIONS
67//=============================================================================================================
68
69XdawnResult Xdawn::fit(const QVector<MNEEpochData>& epochs,
70 int iTargetEvent,
71 int nComponents,
72 double dReg)
73{
74 XdawnResult result;
75 result.iTargetEvent = iTargetEvent;
76
77 if (epochs.isEmpty()) {
78 qWarning() << "Xdawn::fit: empty epoch list.";
79 return result;
80 }
81
82 QVector<int> goodIdx;
83 goodIdx.reserve(epochs.size());
84 for (int i = 0; i < epochs.size(); ++i) {
85 if (!epochs[i].bReject) {
86 goodIdx.append(i);
87 }
88 }
89
90 if (goodIdx.isEmpty()) {
91 qWarning() << "Xdawn::fit: no non-rejected epochs available.";
92 return result;
93 }
94
95 const int nCh = static_cast<int>(epochs[goodIdx[0]].epoch.rows());
96 const int nSamp = static_cast<int>(epochs[goodIdx[0]].epoch.cols());
97 if (nCh == 0 || nSamp == 0) {
98 qWarning() << "Xdawn::fit: epoch matrices are empty.";
99 return result;
100 }
101
102 for (int idx : goodIdx) {
103 if (epochs[idx].epoch.rows() != nCh || epochs[idx].epoch.cols() != nSamp) {
104 qWarning() << "Xdawn::fit: epoch dimension mismatch.";
105 return result;
106 }
107 }
108
109 nComponents = std::max(1, std::min(nComponents, nCh));
110
111 QHash<int, MatrixXd> classSums;
112 QHash<int, int> classCounts;
113 MatrixXd targetSum = MatrixXd::Zero(nCh, nSamp);
114 int nTarget = 0;
115
116 for (int idx : goodIdx) {
117 const MNEEpochData& ep = epochs[idx];
118 if (!classSums.contains(ep.event)) {
119 classSums.insert(ep.event, MatrixXd::Zero(nCh, nSamp));
120 classCounts.insert(ep.event, 0);
121 }
122
123 classSums[ep.event] += ep.epoch;
124 classCounts[ep.event] += 1;
125
126 if (ep.event == iTargetEvent) {
127 targetSum += ep.epoch;
128 ++nTarget;
129 }
130 }
131
132 if (nTarget == 0) {
133 qWarning() << "Xdawn::fit: no target epochs found for event" << iTargetEvent;
134 return result;
135 }
136
137 result.matTargetEvoked = targetSum / static_cast<double>(nTarget);
138
139 MatrixXd noiseCov = MatrixXd::Zero(nCh, nCh);
140 MatrixXd dataCov = MatrixXd::Zero(nCh, nCh);
141 long long nNoiseSamples = 0;
142 long long nDataSamples = 0;
143
144 QHash<int, MatrixXd> classMeans;
145 for (auto it = classSums.constBegin(); it != classSums.constEnd(); ++it) {
146 classMeans.insert(it.key(), it.value() / static_cast<double>(classCounts.value(it.key())));
147 }
148
149 for (int idx : goodIdx) {
150 const MNEEpochData& ep = epochs[idx];
151 const MatrixXd residual = ep.epoch - classMeans.value(ep.event);
152
153 dataCov += ep.epoch * ep.epoch.transpose();
154 noiseCov += residual * residual.transpose();
155 nDataSamples += nSamp;
156 nNoiseSamples += nSamp;
157 }
158
159 if (nNoiseSamples <= 0 || nDataSamples <= 0) {
160 qWarning() << "Xdawn::fit: failed to accumulate covariance samples.";
161 return result;
162 }
163
164 result.matSignalCov = result.matTargetEvoked * result.matTargetEvoked.transpose() / static_cast<double>(nSamp);
165 result.matNoiseCov = noiseCov / static_cast<double>(nNoiseSamples);
166 dataCov = dataCov / static_cast<double>(nDataSamples);
167
168 const double traceNoise = result.matNoiseCov.trace();
169 const double regValue = std::max(dReg, 0.0) * ((traceNoise > 0.0) ? traceNoise / static_cast<double>(nCh) : 1.0);
170 MatrixXd regNoiseCov = result.matNoiseCov;
171 regNoiseCov.diagonal().array() += regValue;
172
173 SelfAdjointEigenSolver<MatrixXd> noiseEig(regNoiseCov);
174 if (noiseEig.info() != Success) {
175 qWarning() << "Xdawn::fit: noise covariance eigendecomposition failed.";
176 return result;
177 }
178
179 VectorXd noiseVals = noiseEig.eigenvalues().cwiseMax(1e-12);
180 MatrixXd noiseVecs = noiseEig.eigenvectors();
181 MatrixXd invSqrtNoise = noiseVecs * noiseVals.cwiseInverse().cwiseSqrt().asDiagonal() * noiseVecs.transpose();
182
183 MatrixXd whitenedSignal = invSqrtNoise * result.matSignalCov * invSqrtNoise;
184 SelfAdjointEigenSolver<MatrixXd> signalEig(whitenedSignal);
185 if (signalEig.info() != Success) {
186 qWarning() << "Xdawn::fit: signal covariance eigendecomposition failed.";
187 return result;
188 }
189
190 result.matFilters.resize(nCh, nComponents);
191 const MatrixXd signalVecsAsc = signalEig.eigenvectors().rightCols(nComponents);
192 MatrixXd signalVecs(signalVecsAsc.rows(), signalVecsAsc.cols());
193 for (int i = 0; i < nComponents; ++i) {
194 signalVecs.col(i) = signalVecsAsc.col(nComponents - 1 - i);
195 }
196
197 result.matFilters = invSqrtNoise * signalVecs;
198
199 for (int col = 0; col < result.matFilters.cols(); ++col) {
200 const double noiseNorm = std::sqrt(result.matFilters.col(col).transpose()
201 * regNoiseCov
202 * result.matFilters.col(col));
203 if (noiseNorm > 1e-12) {
204 result.matFilters.col(col) /= noiseNorm;
205 }
206 }
207
208 result.matPatterns = computePatterns(result.matFilters, dataCov);
209 result.bValid = true;
210 return result;
211}
212
213//=============================================================================================================
214
215MatrixXd Xdawn::apply(const MatrixXd& matEpoch, const XdawnResult& result)
216{
217 if (!result.bValid || result.matFilters.size() == 0) {
218 return {};
219 }
220
221 if (matEpoch.rows() != result.matFilters.rows()) {
222 qWarning() << "Xdawn::apply: channel count mismatch.";
223 return {};
224 }
225
226 return result.matFilters.transpose() * matEpoch;
227}
228
229//=============================================================================================================
230
231MatrixXd Xdawn::denoise(const MatrixXd& matEpoch, const XdawnResult& result, int nComponents)
232{
233 if (!result.bValid || result.matFilters.size() == 0 || result.matPatterns.size() == 0) {
234 return matEpoch;
235 }
236
237 if (matEpoch.rows() != result.matFilters.rows()) {
238 qWarning() << "Xdawn::denoise: channel count mismatch.";
239 return {};
240 }
241
242 if (nComponents <= 0 || nComponents > result.matFilters.cols()) {
243 nComponents = result.matFilters.cols();
244 }
245
246 MatrixXd filters = result.matFilters.leftCols(nComponents);
247 MatrixXd patterns = result.matPatterns.leftCols(nComponents);
248
249 return patterns * (filters.transpose() * matEpoch);
250}
251
252//=============================================================================================================
253
254QVector<MNEEpochData> Xdawn::denoiseEpochs(const QVector<MNEEpochData>& epochs,
255 const XdawnResult& result,
256 int nComponents)
257{
258 QVector<MNEEpochData> out = epochs;
259 for (int i = 0; i < out.size(); ++i) {
260 out[i].epoch = denoise(out[i].epoch, result, nComponents);
261 }
262 return out;
263}
Declaration of the Xdawn class for event-related response enhancement.
Core MNE data structures (source spaces, source estimates, hemispheres).
Shared utilities (I/O helpers, spectral analysis, layout management, warp algorithms).
Result of an xDAWN decomposition.
Definition xdawn.h:53
Eigen::MatrixXd matSignalCov
Definition xdawn.h:56
Eigen::MatrixXd matPatterns
Definition xdawn.h:55
Eigen::MatrixXd matNoiseCov
Definition xdawn.h:57
Eigen::MatrixXd matFilters
Definition xdawn.h:54
Eigen::MatrixXd matTargetEvoked
Definition xdawn.h:58
static Eigen::MatrixXd apply(const Eigen::MatrixXd &matEpoch, const XdawnResult &result)
Definition xdawn.cpp:215
static QVector< MNELIB::MNEEpochData > denoiseEpochs(const QVector< MNELIB::MNEEpochData > &epochs, const XdawnResult &result, int nComponents=-1)
Definition xdawn.cpp:254
static XdawnResult fit(const QVector< MNELIB::MNEEpochData > &epochs, int iTargetEvent=1, int nComponents=2, double dReg=1e-6)
Definition xdawn.cpp:69
static Eigen::MatrixXd denoise(const Eigen::MatrixXd &matEpoch, const XdawnResult &result, int nComponents=-1)
Definition xdawn.cpp:231
Single epoch (trial slice) of sensor data with timing and rejection metadata.
Eigen::MatrixXd epoch
FIFFLIB::fiff_int_t event