45#include <Eigen/Eigenvalues>
52#include <QCoreApplication>
// Computes the contextual-MNE (CMNE) source estimate from evoked data in four
// steps: 1) dSPM inverse kernel, 2) projection to source space, 3) z-score
// rectification, 4) LSTM / moving-average temporal correction.
// NOTE(review): this listing is a lossy extraction — the function header, the
// bodies of the Step-4 branches, and the result assembly are not visible here.
// Comments below describe only what the visible statements show; TODO confirm
// the branch structure against the full source.
78 const MatrixXd& matEvoked,
79 const MatrixXd& matGain,
80 const MatrixXd& matNoiseCov,
81 const MatrixXd& matSrcCov,
// Problem dimensions from the gain matrix (channels x sources) and the evoked
// data (channels x time points).
86 int nChannels = matGain.rows();
87 int nSources = matGain.cols();
88 int nTimes = matEvoked.cols();
// Step 1/4: build the dSPM imaging kernel (sources x channels), regularized
// by settings.lambda2.
91 qInfo() <<
"[InvCMNE] Step 1/4: Computing dSPM kernel"
92 <<
"(" << nChannels <<
"ch x" << nSources <<
"src, lambda2="
94 MatrixXd matKernelDspm = computeDspmKernel(matGain, matNoiseCov, matSrcCov, settings.
lambda2);
96 qInfo() <<
"[InvCMNE] Step 1/4: dSPM kernel done"
97 <<
"(" << matKernelDspm.rows() <<
"x" << matKernelDspm.cols() <<
").";
// Step 2/4: project the evoked data through the kernel into source space
// (sources x time points).
100 qInfo() <<
"[InvCMNE] Step 2/4: Projecting evoked data to source space"
101 <<
"(" << nTimes <<
"time points) …";
102 MatrixXd matDspmData = matKernelDspm * matEvoked;
// Sequential vertex indices 0..nSources-1 — presumably used to build the
// InvSourceEstimate containers; TODO confirm against the full source.
105 VectorXi vertices = VectorXi::LinSpaced(matDspmData.rows(), 0, matDspmData.rows() - 1);
107 qInfo() <<
"[InvCMNE] Step 2/4: dSPM source estimate done"
108 <<
"(" << matDspmData.rows() <<
"sources x" << matDspmData.cols() <<
"samples).";
// Step 3/4: rectify and z-score each source trace.
// NOTE(review): matZScored's consumer is not visible in this view.
111 qInfo() <<
"[InvCMNE] Step 3/4: Z-score rectifying source data …";
112 MatrixXd matZScored = zScoreRectify(matDspmData);
113 qInfo() <<
"[InvCMNE] Step 3/4: Z-score rectification done.";
// Step 4/4: temporal correction. matCmneData receives the corrected data; the
// log messages below belong to three branches (LSTM via ONNX model,
// moving-average fallback, and a "too few time points" bail-out) whose
// surrounding control flow is missing from this view.
116 MatrixXd matCmneData;
119 qInfo() <<
"[InvCMNE] Step 4/4: Applying LSTM temporal correction"
120 <<
"(look-back=" << settings.
lookBack <<
","
121 << (nTimes - settings.
lookBack) <<
"correctable time points) …";
126 qInfo() <<
"[InvCMNE] Step 4/4: LSTM correction done.";
// Fallback path keeps the plain dSPM data as the CMNE result before (or while)
// applying the moving-average correction — exact placement not visible here.
129 matCmneData = matDspmData;
132 qInfo() <<
"[InvCMNE] Step 4/4: No ONNX model — using moving-average correction.";
134 qInfo() <<
"[InvCMNE] Step 4/4: Moving-average correction done.";
// Bail-out: fewer time points than the look-back window requires.
136 qInfo() <<
"[InvCMNE] Step 4/4: Not enough time points for lookBack window"
137 <<
"(need" << settings.
lookBack <<
", have" << nTimes <<
").";
// Builds the dSPM imaging kernel K (nSources x nChannels):
//   1) whiten the gain with C_n^{-1/2} (symmetric eigendecomposition),
//   2) form the regularized MNE kernel K = R G_w^T (G_w R G_w^T + lambda2*I)^{-1},
//   3) normalize each source row by its noise sensitivity sqrt(diag(K C_n K^T)).
// NOTE(review): extraction gaps — the lambda2 parameter line, loop-closing
// braces, orig. lines 186-193 (which may post-process matK, e.g. apply the
// whitener), and the return statement are not visible; confirm against the
// full source.
149MatrixXd InvCMNE::computeDspmKernel(
150 const MatrixXd& matGain,
151 const MatrixXd& matNoiseCov,
152 const MatrixXd& matSrcCov,
155 int nChannels = matGain.rows();
156 int nSources = matGain.cols();
160 qInfo() <<
" [dSPM kernel] Eigendecomposition of noise covariance"
161 <<
"(" << nChannels <<
"x" << nChannels <<
") …";
// Symmetric eigensolver: matNoiseCov is a covariance, hence self-adjoint.
162 SelfAdjointEigenSolver<MatrixXd> eigSolver(matNoiseCov);
163 VectorXd eigVals = eigSolver.eigenvalues();
164 MatrixXd eigVecs = eigSolver.eigenvectors();
// Inverse square root of the eigenvalues with a relative cutoff (1e-10 of the
// largest eigenvalue) so near-zero modes of a rank-deficient covariance are
// zeroed instead of amplified.
167 double maxEig = eigVals.maxCoeff();
168 double threshold = maxEig * 1e-10;
169 VectorXd eigValsInvSqrt(nChannels);
170 for (
int i = 0; i < nChannels; ++i) {
171 eigValsInvSqrt(i) = (eigVals(i) > threshold) ? 1.0 / std::sqrt(eigVals(i)) : 0.0;
// Whitener W = V * diag(1/sqrt(lambda_i)) * V^T, i.e. C_n^{-1/2}.
174 MatrixXd matWhitener = eigVecs * eigValsInvSqrt.asDiagonal() * eigVecs.transpose();
177 qInfo() <<
" [dSPM kernel] Whitening gain matrix …";
178 MatrixXd matGainWhitened = matWhitener * matGain;
181 qInfo() <<
" [dSPM kernel] Computing MNE kernel (LDLT solve," << nChannels <<
"x" << nChannels <<
") …";
// A = G_w R G_w^T + lambda2 * I, regularized via the in-place diagonal add.
183 MatrixXd matGCR = matGainWhitened * matSrcCov;
184 MatrixXd matA = matGCR * matGainWhitened.transpose();
185 matA.diagonal().array() += lambda2;
// LDLT (A is symmetric) solved against the identity, i.e. an explicit A^{-1}.
// NOTE(review): solving directly against the right-hand side would avoid
// forming the explicit inverse — consider if this path is hot.
188 auto ldlt = matA.ldlt();
189 MatrixXd matK = (matSrcCov * matGainWhitened.transpose()) * ldlt.solve(MatrixXd::Identity(nChannels, nChannels));
// dSPM noise normalization: each row i is divided by
// sqrt((K C_n K^T)_{ii}), computed row-wise without forming the full product.
194 qInfo() <<
" [dSPM kernel] Normalizing" << nSources <<
"source rows …";
195 MatrixXd matKCn = matK * matNoiseCov;
196 for (
int i = 0; i < nSources; ++i) {
197 double noiseNorm = std::sqrt(matKCn.row(i).dot(matK.row(i)));
// Guard: leave rows with (near-)zero noise sensitivity untouched.
198 if (noiseNorm > 1e-10) {
199 matK.row(i) /= noiseNorm;
208MatrixXd InvCMNE::zScoreRectify(
const MatrixXd& matStcData)
210 int nSources = matStcData.rows();
211 int nTimes = matStcData.cols();
213 MatrixXd matResult(nSources, nTimes);
215 for (
int i = 0; i < nSources; ++i) {
217 VectorXd absRow = matStcData.row(i).cwiseAbs();
220 double mu = absRow.mean();
221 double variance = (absRow.array() - mu).square().mean();
222 double sigma = std::sqrt(variance);
225 double denom = std::max(sigma, 1e-10);
226 matResult.row(i) = (absRow.array() - mu) / denom;
// Applies temporal correction to dSPM source data with a sliding look-back
// window: LSTM inference when an ONNX model loads, otherwise a row-wise moving
// average. Works autoregressively — each window is read from the already
// corrected `result` columns — and uses the normalized prediction magnitude as
// a per-source gain on the original dSPM column.
// NOTE(review): extraction gaps — the function header, the lstmModel
// declaration, the predict() call (orig. lines 294-300), and brace closures are
// not visible in this view.
235 const MatrixXd& matDspmData,
236 const QString& onnxModelPath,
239 int nSources = matDspmData.rows();
240 int nTimes = matDspmData.cols();
// Start from a copy: columns [0, lookBack) pass through unchanged.
242 MatrixXd result = matDspmData;
// Progress is logged roughly ten times over the correctable range; qMax guards
// against a zero interval for very short ranges.
244 int nCorrectableSteps = nTimes - lookBack;
245 int reportInterval = qMax(1, nCorrectableSteps / 10);
// Decide the correction mode once, up front: LSTM only if a model path was
// given AND the model actually loads.
251 if (!onnxModelPath.isEmpty()) {
252 if (lstmModel.
load(onnxModelPath)) {
254 qInfo() <<
" [LSTM correction] ONNX model loaded — using LSTM inference.";
256 qWarning() <<
" [LSTM correction] Failed to load ONNX model — falling back to moving average.";
259 qInfo() <<
" [LSTM correction] No ONNX model path — using moving average.";
// Reusable float32 input buffer, row-major [lookBack x nSources], allocated
// once outside the time loop.
264 std::vector<float> inputBuf;
266 inputBuf.resize(
static_cast<size_t>(lookBack) *
static_cast<size_t>(nSources));
// Main time loop over all correctable columns t in [lookBack, nTimes).
270 for (
int t = lookBack; t < nTimes; ++t) {
271 int step = t - lookBack;
272 if (step % reportInterval == 0 || t == nTimes - 1) {
273 double pct = 100.0 * (step + 1) / nCorrectableSteps;
274 qInfo().noquote() << QString(
" [LSTM correction] %1% (%2/%3 time steps)")
275 .arg(pct, 0,
'f', 0).arg(step + 1).arg(nCorrectableSteps);
// Fill the window with the lookBack columns preceding t, taken from the
// already-corrected `result` (autoregressive). Layout: time-major, sources
// contiguous within each time step.
283 for (
int k = 0; k < lookBack; ++k) {
284 int col = t - lookBack + k;
285 for (
int s = 0; s < nSources; ++s) {
286 inputBuf[
static_cast<size_t>(k) *
static_cast<size_t>(nSources)
287 +
static_cast<size_t>(s)] =
static_cast<float>(result(s, col));
// Tensor shape {batch=1, lookBack, nSources} for the ONNX model.
292 std::vector<int64_t> inputShape = {1,
static_cast<int64_t
>(lookBack),
293 static_cast<int64_t
>(nSources)};
// Copy the model output (one float per source) back to double precision.
// NOTE(review): the predict() call producing outputTensor is in the missing
// lines above.
301 prediction.resize(nSources);
302 const float* outPtr = outputTensor.
data();
303 for (
int s = 0; s < nSources; ++s) {
304 prediction(s) =
static_cast<double>(outPtr[s]);
// Moving-average fallback: per-source mean over the same look-back window.
308 MatrixXd window = result.middleCols(t - lookBack, lookBack);
309 prediction = window.rowwise().mean();
// Normalize the prediction magnitude to [0, 1] (skip if all-near-zero) and
// use it as a multiplicative mask on the ORIGINAL dSPM column.
313 double maxVal = prediction.cwiseAbs().maxCoeff();
314 if (maxVal > 1e-10) {
315 prediction = prediction.cwiseAbs() / maxVal;
319 result.col(t) = prediction.cwiseProduct(matDspmData.col(t));
// Launches the Python training script train_cmne_lstm.py with command-line
// arguments derived from the settings, via the ML trainer wrapper, and returns
// the script-execution result.
// NOTE(review): extraction gaps — the function header, the `result`, `args`,
// `methodStr`, `config` and `trainer` declarations, and brace closures are not
// visible in this view.
330 const QString& fwdPath,
331 const QString& covPath,
332 const QString& epochsPath,
333 const QString& outOnnxPath,
335 const QString& gtStcPrefix,
341 const QString& finetuneOnnxPath,
342 const QString& pythonExe)
// Locate the packaged training directory relative to the application binary;
// fall back to the repo-relative path when pyproject.toml is not found there
// (development runs from the source tree).
346 QString appDir = QCoreApplication::applicationDirPath();
347 QString cmneDir = QDir(appDir).absoluteFilePath(
348 QStringLiteral(
"../scripts/ml/training/cmne"));
351 if (!QFile::exists(QDir(cmneDir).absoluteFilePath(QStringLiteral(
"pyproject.toml")))) {
352 cmneDir = QStringLiteral(
"scripts/ml/training/cmne");
355 QString scriptPath = QDir(cmneDir).absoluteFilePath(QStringLiteral(
"train_cmne_lstm.py"));
// Fail early with a descriptive stderr message when the script is absent.
357 if (!QFile::exists(scriptPath)) {
359 result.
stdErr = QStringLiteral(
"Training script not found: ") + scriptPath;
360 qWarning() <<
"[InvCMNE::trainLstm]" << result.
stdErr;
364 qDebug() <<
"[InvCMNE::trainLstm] Script:" << scriptPath;
365 qDebug() <<
"[InvCMNE::trainLstm] Package dir:" << cmneDir;
// Map the numeric inverse-method code to the script's --method string;
// unknown codes default to dSPM.
369 switch (settings.
method) {
370 case 0: methodStr = QStringLiteral(
"MNE");
break;
371 case 1: methodStr = QStringLiteral(
"dSPM");
break;
372 case 2: methodStr = QStringLiteral(
"sLORETA");
break;
373 case 3: methodStr = QStringLiteral(
"eLORETA");
break;
374 default: methodStr = QStringLiteral(
"dSPM");
break;
// The script takes SNR, not lambda2; standard MNE relation snr = 1/sqrt(lambda2).
377 double snr = 1.0 / std::sqrt(settings.
lambda2);
// Assemble the CLI argument list for the training script.
381 args << QStringLiteral(
"--fwd") << fwdPath
382 << QStringLiteral(
"--cov") << covPath
383 << QStringLiteral(
"--epochs") << epochsPath
384 << QStringLiteral(
"--out") << outOnnxPath
385 << QStringLiteral(
"--look-back") << QString::number(settings.
lookBack)
386 << QStringLiteral(
"--method") << methodStr
387 << QStringLiteral(
"--snr") << QString::number(snr,
'g', 6)
388 << QStringLiteral(
"--hidden") << QString::number(hiddenSize)
389 << QStringLiteral(
"--layers") << QString::number(numLayers)
390 << QStringLiteral(
"--train-epochs") << QString::number(trainEpochs)
391 << QStringLiteral(
"--lr") << QString::number(learningRate,
'g', 6)
392 << QStringLiteral(
"--batch") << QString::number(batchSize);
// Optional: ground-truth source-estimate prefix for supervised training.
394 if (!gtStcPrefix.isEmpty()) {
395 args << QStringLiteral(
"--gt-stc") << gtStcPrefix;
// Optional: start from an existing ONNX checkpoint (fine-tuning).
398 if (!finetuneOnnxPath.isEmpty()) {
399 args << QStringLiteral(
"--finetune") << finetuneOnnxPath;
// Use a virtual environment co-located with the training package.
406 config.
venvDir = QDir(cmneDir).absoluteFilePath(QStringLiteral(
".venv"));
411 return trainer.
run(scriptPath, args);
InvCMNE class declaration (Contextual MNE, Dinh et al. 2021).
MLTrainer class declaration — ML-specific convenience wrapper over PythonRunner.
MlOnnxModel class declaration.
MlTensor class declaration — N-dimensional, row-major, zero-copy.
Inverse source estimation (MNE, dSPM, sLORETA, dipole fitting).
InvSourceEstimate stcDspm
InvSourceEstimate stcCmne
Eigen::MatrixXd matKernelDspm
InvSourceEstimate stcLstmPredict
static InvCMNEResult compute(const Eigen::MatrixXd &matEvoked, const Eigen::MatrixXd &matGain, const Eigen::MatrixXd &matNoiseCov, const Eigen::MatrixXd &matSrcCov, const InvCMNESettings &settings)
static Eigen::MatrixXd applyLstmCorrection(const Eigen::MatrixXd &matDspmData, const QString &onnxModelPath, int lookBack)
static UTILSLIB::PythonRunnerResult trainLstm(const QString &fwdPath, const QString &covPath, const QString &epochsPath, const QString &outOnnxPath, const InvCMNESettings &settings, const QString &gtStcPrefix={}, int hiddenSize=256, int numLayers=1, int trainEpochs=50, double learningRate=1e-3, int batchSize=64, const QString &finetuneOnnxPath={}, const QString &pythonExe=QStringLiteral("python3"))
ONNX Runtime backed model.
bool load(const QString &path) override
MlTensor predict(const MlTensor &input) const override
N-dimensional tensor with contiguous row-major (C-order) float32 storage.
static MlTensor view(float *data, std::vector< int64_t > shape)
ML training script launcher.
UTILSLIB::PythonRunnerResult run(const QString &scriptPath, const QStringList &args={})
Script execution result container.
Script execution configuration.