v2.0.0
Loading...
Searching...
No Matches
inv_cmne.cpp
Go to the documentation of this file.
1//=============================================================================================================
34
35//=============================================================================================================
36// INCLUDES
37//=============================================================================================================
38
39#include "inv_cmne.h"
40
41//=============================================================================================================
42// EIGEN INCLUDES
43//=============================================================================================================
44
45#include <Eigen/Eigenvalues>
46
47//=============================================================================================================
48// QT INCLUDES
49//=============================================================================================================
50
51#include <QDebug>
52#include <QCoreApplication>
53#include <QDir>
54
55//=============================================================================================================
56// MNE-CPP INCLUDES
57//=============================================================================================================
58
59#include <ml/ml_onnx_model.h>
60#include <ml/ml_tensor.h>
61
62#ifndef WASMBUILD
63#include <ml/ml_trainer.h>
64#endif
65
66//=============================================================================================================
67// USED NAMESPACES
68//=============================================================================================================
69
70using namespace INVLIB;
71using namespace Eigen;
72
73//=============================================================================================================
74// DEFINE MEMBER METHODS
75//=============================================================================================================
76
    const MatrixXd& matEvoked,       // sensor-space data; row count must match matGain's channels
    const MatrixXd& matGain,         // forward solution, n_channels x n_sources
    const MatrixXd& matNoiseCov,     // noise covariance, n_channels x n_channels
    const MatrixXd& matSrcCov,       // source covariance, n_sources x n_sources
    const InvCMNESettings& settings) // supplies lambda2, lookBack, onnxModelPath
{
    // Full CMNE pipeline: dSPM kernel -> source projection -> z-score
    // rectification -> LSTM (or moving-average) temporal correction.
    // Returns dSPM, raw LSTM prediction (when available), and CMNE estimates.
    InvCMNEResult result;

    int nChannels = matGain.rows();
    int nSources = matGain.cols();
    int nTimes = matEvoked.cols();

    // Step 1: Compute dSPM kernel (n_sources x n_channels).
    // NOTE(review): the kernel is applied to matEvoked as-is below — confirm
    // its whitening convention matches unwhitened sensor data.
    qInfo() << "[InvCMNE] Step 1/4: Computing dSPM kernel"
            << "(" << nChannels << "ch x" << nSources << "src, lambda2="
            << settings.lambda2 << ") …";
    MatrixXd matKernelDspm = computeDspmKernel(matGain, matNoiseCov, matSrcCov, settings.lambda2);
    result.matKernelDspm = matKernelDspm;
    qInfo() << "[InvCMNE] Step 1/4: dSPM kernel done"
            << "(" << matKernelDspm.rows() << "x" << matKernelDspm.cols() << ").";

    // Step 2: Apply kernel to evoked data -> dSPM source estimate
    qInfo() << "[InvCMNE] Step 2/4: Projecting evoked data to source space"
            << "(" << nTimes << "time points) …";
    MatrixXd matDspmData = matKernelDspm * matEvoked; // n_sources x n_times

    // Build dSPM source estimate. Vertex ids are synthetic (0..n_sources-1)
    // and tmin=0 / tstep=1 are placeholder timing, not acquisition times.
    VectorXi vertices = VectorXi::LinSpaced(matDspmData.rows(), 0, matDspmData.rows() - 1);
    result.stcDspm = InvSourceEstimate(matDspmData, vertices, 0.0f, 1.0f);
    qInfo() << "[InvCMNE] Step 2/4: dSPM source estimate done"
            << "(" << matDspmData.rows() << "sources x" << matDspmData.cols() << "samples).";

    // Step 3: Z-score rectify (per-source |.| then z-scoring across time)
    qInfo() << "[InvCMNE] Step 3/4: Z-score rectifying source data …";
    MatrixXd matZScored = zScoreRectify(matDspmData);
    qInfo() << "[InvCMNE] Step 3/4: Z-score rectification done.";

    // Step 4: Apply LSTM correction if model available and enough time points
    MatrixXd matCmneData;

    if (!settings.onnxModelPath.isEmpty() && nTimes >= settings.lookBack) {
        // LSTM path operates on the z-scored data.
        qInfo() << "[InvCMNE] Step 4/4: Applying LSTM temporal correction"
                << "(look-back=" << settings.lookBack << ","
                << (nTimes - settings.lookBack) << "correctable time points) …";
        matCmneData = applyLstmCorrection(matZScored, settings.onnxModelPath, settings.lookBack);

        // Store raw LSTM prediction for diagnostics
        result.stcLstmPredict = InvSourceEstimate(matCmneData, vertices, 0.0f, 1.0f);
        qInfo() << "[InvCMNE] Step 4/4: LSTM correction done.";
    } else {
        // No correction possible — CMNE falls back to dSPM
        matCmneData = matDspmData;

        if (settings.onnxModelPath.isEmpty()) {
            // Empty model path triggers the moving-average fallback inside
            // applyLstmCorrection. NOTE(review): this branch feeds the raw
            // dSPM data, while the LSTM branch above feeds z-scored data —
            // confirm the asymmetry is intentional.
            qInfo() << "[InvCMNE] Step 4/4: No ONNX model — using moving-average correction.";
            matCmneData = applyLstmCorrection(matDspmData, QString(), settings.lookBack);
            qInfo() << "[InvCMNE] Step 4/4: Moving-average correction done.";
        } else {
            // Model present but too few samples for one look-back window:
            // keep the plain dSPM estimate assigned above.
            qInfo() << "[InvCMNE] Step 4/4: Not enough time points for lookBack window"
                    << "(need" << settings.lookBack << ", have" << nTimes << ").";
        }
    }

    // Build CMNE source estimate
    result.stcCmne = InvSourceEstimate(matCmneData, vertices, 0.0f, 1.0f);

    return result;
}
146
147//=============================================================================================================
148
149MatrixXd InvCMNE::computeDspmKernel(
150 const MatrixXd& matGain,
151 const MatrixXd& matNoiseCov,
152 const MatrixXd& matSrcCov,
153 double lambda2)
154{
155 int nChannels = matGain.rows();
156 int nSources = matGain.cols();
157
158 // Step 1: Whiten noise covariance via eigendecomposition
159 // C_n = V * D * V^T -> C_n^{-1/2} = V * D^{-1/2} * V^T
160 qInfo() << " [dSPM kernel] Eigendecomposition of noise covariance"
161 << "(" << nChannels << "x" << nChannels << ") …";
162 SelfAdjointEigenSolver<MatrixXd> eigSolver(matNoiseCov);
163 VectorXd eigVals = eigSolver.eigenvalues();
164 MatrixXd eigVecs = eigSolver.eigenvectors();
165
166 // Regularize: clamp small eigenvalues
167 double maxEig = eigVals.maxCoeff();
168 double threshold = maxEig * 1e-10;
169 VectorXd eigValsInvSqrt(nChannels);
170 for (int i = 0; i < nChannels; ++i) {
171 eigValsInvSqrt(i) = (eigVals(i) > threshold) ? 1.0 / std::sqrt(eigVals(i)) : 0.0;
172 }
173
174 MatrixXd matWhitener = eigVecs * eigValsInvSqrt.asDiagonal() * eigVecs.transpose();
175
176 // Step 2: Whiten gain matrix
177 qInfo() << " [dSPM kernel] Whitening gain matrix …";
178 MatrixXd matGainWhitened = matWhitener * matGain; // n_channels x n_sources
179
180 // Step 3: MNE kernel
181 qInfo() << " [dSPM kernel] Computing MNE kernel (LDLT solve," << nChannels << "x" << nChannels << ") …";
182 // K = C_R * G_tilde^T * (G_tilde * C_R * G_tilde^T + lambda2 * I)^{-1}
183 MatrixXd matGCR = matGainWhitened * matSrcCov; // n_channels x n_sources
184 MatrixXd matA = matGCR * matGainWhitened.transpose(); // n_channels x n_channels
185 matA.diagonal().array() += lambda2;
186
187 // Solve once: A^{-1} via LDLT, then K = (C_R * G_tilde^T) * A^{-1}
188 auto ldlt = matA.ldlt();
189 MatrixXd matK = (matSrcCov * matGainWhitened.transpose()) * ldlt.solve(MatrixXd::Identity(nChannels, nChannels));
190
191 // Step 4: dSPM normalization
192 // noise_norm_i = sqrt((K * C_n * K^T)(i,i))
193 // K_dSPM(i,:) = K(i,:) / noise_norm_i
194 qInfo() << " [dSPM kernel] Normalizing" << nSources << "source rows …";
195 MatrixXd matKCn = matK * matNoiseCov; // n_sources x n_channels
196 for (int i = 0; i < nSources; ++i) {
197 double noiseNorm = std::sqrt(matKCn.row(i).dot(matK.row(i)));
198 if (noiseNorm > 1e-10) {
199 matK.row(i) /= noiseNorm;
200 }
201 }
202
203 return matK; // n_sources x n_channels (dSPM kernel)
204}
205
206//=============================================================================================================
207
208MatrixXd InvCMNE::zScoreRectify(const MatrixXd& matStcData)
209{
210 int nSources = matStcData.rows();
211 int nTimes = matStcData.cols();
212
213 MatrixXd matResult(nSources, nTimes);
214
215 for (int i = 0; i < nSources; ++i) {
216 // Absolute value
217 VectorXd absRow = matStcData.row(i).cwiseAbs();
218
219 // Mean and standard deviation across time
220 double mu = absRow.mean();
221 double variance = (absRow.array() - mu).square().mean();
222 double sigma = std::sqrt(variance);
223
224 // Z-score (guard against zero std)
225 double denom = std::max(sigma, 1e-10);
226 matResult.row(i) = (absRow.array() - mu) / denom;
227 }
228
229 return matResult;
230}
231
232//=============================================================================================================
233
    const MatrixXd& matDspmData,  // source-space data, n_sources x n_times
    const QString& onnxModelPath, // empty (or failed load) -> moving-average fallback
    int lookBack)                 // temporal window length in samples
{
    // CMNE temporal correction (Eqs. 12-13 of the paper): for each t >=
    // lookBack, predict the source activation pattern from the previous
    // lookBack columns and modulate the current dSPM column with the
    // normalized prediction. Columns t < lookBack are returned unchanged.
    int nSources = matDspmData.rows();
    int nTimes = matDspmData.cols();

    MatrixXd result = matDspmData; // copy — for t < lookBack: identity (no correction)

    int nCorrectableSteps = nTimes - lookBack;
    int reportInterval = qMax(1, nCorrectableSteps / 10); // report ~10 times

    // Try to load ONNX model for LSTM inference
    MLLIB::MlOnnxModel lstmModel;
    bool useOrt = false;

    if (!onnxModelPath.isEmpty()) {
        if (lstmModel.load(onnxModelPath)) {
            useOrt = true;
            qInfo() << " [LSTM correction] ONNX model loaded — using LSTM inference.";
        } else {
            // Best-effort: degrade to the moving-average control estimate
            // rather than failing the whole pipeline.
            qWarning() << " [LSTM correction] Failed to load ONNX model — falling back to moving average.";
        }
    } else {
        qInfo() << " [LSTM correction] No ONNX model path — using moving average.";
    }

    // Pre-allocate input buffer for ORT: shape [1, lookBack, nSources] (batch, seq, features)
    // Row-major layout: [seq][features]
    std::vector<float> inputBuf;
    if (useOrt) {
        inputBuf.resize(static_cast<size_t>(lookBack) * static_cast<size_t>(nSources));
    }

    // For t >= lookBack: apply temporal correction
    for (int t = lookBack; t < nTimes; ++t) {
        int step = t - lookBack;
        // Progress log roughly every 10% of correctable steps (and at the end).
        if (step % reportInterval == 0 || t == nTimes - 1) {
            double pct = 100.0 * (step + 1) / nCorrectableSteps;
            qInfo().noquote() << QString(" [LSTM correction] %1% (%2/%3 time steps)")
                .arg(pct, 0, 'f', 0).arg(step + 1).arg(nCorrectableSteps);
        }

        VectorXd prediction;

        if (useOrt) {
            // Fill input buffer: double→float, column-major→row-major
            // Layout: inputBuf[k * nSources + s] = matDspmData(s, t - lookBack + k)
            // The window reads from `result`, so it sees previously corrected
            // columns — the correction is autoregressive.
            for (int k = 0; k < lookBack; ++k) {
                int col = t - lookBack + k;
                for (int s = 0; s < nSources; ++s) {
                    inputBuf[static_cast<size_t>(k) * static_cast<size_t>(nSources)
                             + static_cast<size_t>(s)] = static_cast<float>(result(s, col));
                }
            }

            // Create MlTensor view over the pre-allocated buffer — zero-copy
            std::vector<int64_t> inputShape = {1, static_cast<int64_t>(lookBack),
                                               static_cast<int64_t>(nSources)};
            MLLIB::MlTensor inputTensor = MLLIB::MlTensor::view(inputBuf.data(), inputShape);

            // Run LSTM inference
            MLLIB::MlTensor outputTensor = lstmModel.predict(inputTensor);

            // Convert output to Eigen VectorXd
            // Expected output shape: [1, nSources] or [nSources]
            // NOTE(review): the first nSources floats are read unconditionally —
            // confirm the model always emits at least nSources values.
            prediction.resize(nSources);
            const float* outPtr = outputTensor.data();
            for (int s = 0; s < nSources; ++s) {
                prediction(s) = static_cast<double>(outPtr[s]);
            }
        } else {
            // Moving average fallback (control estimate from paper); also
            // windows over `result`, i.e. previously corrected columns.
            MatrixXd window = result.middleCols(t - lookBack, lookBack);
            prediction = window.rowwise().mean();
        }

        // Normalize prediction (Eq. 12): rectify and scale so max |p| = 1.
        // Skipped for (near-)all-zero predictions to avoid division by zero,
        // leaving the raw (unnormalized) prediction in place.
        double maxVal = prediction.cwiseAbs().maxCoeff();
        if (maxVal > 1e-10) {
            prediction = prediction.cwiseAbs() / maxVal;
        }

        // CMNE correction: element-wise product (Eq. 13) with the ORIGINAL
        // (uncorrected) dSPM column, not the autoregressive buffer.
        result.col(t) = prediction.cwiseProduct(matDspmData.col(t));
    }

    return result;
}
324
325//=============================================================================================================
326
327#ifndef WASMBUILD
328
    const QString& fwdPath,          // forwarded to the script as --fwd
    const QString& covPath,          // forwarded as --cov
    const QString& epochsPath,       // forwarded as --epochs
    const QString& outOnnxPath,      // forwarded as --out (trained model destination)
    const InvCMNESettings& settings, // supplies lookBack, method, lambda2 (-> snr)
    const QString& gtStcPrefix,      // optional --gt-stc (ground-truth source estimates)
    int hiddenSize,                  // --hidden
    int numLayers,                   // --layers
    int trainEpochs,                 // --train-epochs
    double learningRate,             // --lr
    int batchSize,                   // --batch
    const QString& finetuneOnnxPath, // optional --finetune (warm-start from existing model)
    const QString& pythonExe)        // Python interpreter used for the venv bootstrap
{
    // Launches the Python CMNE-LSTM training script through MLTrainer and
    // returns the script execution result (including captured stderr).

    // Resolve training package directory (contains pyproject.toml + script)
    // Expected layout: <app_dir>/../scripts/ml/training/cmne/
    QString appDir = QCoreApplication::applicationDirPath();
    QString cmneDir = QDir(appDir).absoluteFilePath(
        QStringLiteral("../scripts/ml/training/cmne"));

    // Fallback: source tree relative to working directory
    if (!QFile::exists(QDir(cmneDir).absoluteFilePath(QStringLiteral("pyproject.toml")))) {
        cmneDir = QStringLiteral("scripts/ml/training/cmne");
    }

    QString scriptPath = QDir(cmneDir).absoluteFilePath(QStringLiteral("train_cmne_lstm.py"));

    if (!QFile::exists(scriptPath)) {
        // Early out with an error result when the script cannot be located.
        result.stdErr = QStringLiteral("Training script not found: ") + scriptPath;
        qWarning() << "[InvCMNE::trainLstm]" << result.stdErr;
        return result;
    }

    qDebug() << "[InvCMNE::trainLstm] Script:" << scriptPath;
    qDebug() << "[InvCMNE::trainLstm] Package dir:" << cmneDir;

    // Map method integer to string (the CLI expects the textual method name;
    // unknown values default to dSPM)
    QString methodStr;
    switch (settings.method) {
    case 0: methodStr = QStringLiteral("MNE"); break;
    case 1: methodStr = QStringLiteral("dSPM"); break;
    case 2: methodStr = QStringLiteral("sLORETA"); break;
    case 3: methodStr = QStringLiteral("eLORETA"); break;
    default: methodStr = QStringLiteral("dSPM"); break;
    }

    // Conventional MNE relationship lambda2 = 1/snr^2, inverted for the CLI.
    double snr = 1.0 / std::sqrt(settings.lambda2);

    // Build argument list matching train_cmne_lstm.py CLI
    QStringList args;
    args << QStringLiteral("--fwd") << fwdPath
         << QStringLiteral("--cov") << covPath
         << QStringLiteral("--epochs") << epochsPath
         << QStringLiteral("--out") << outOnnxPath
         << QStringLiteral("--look-back") << QString::number(settings.lookBack)
         << QStringLiteral("--method") << methodStr
         << QStringLiteral("--snr") << QString::number(snr, 'g', 6)
         << QStringLiteral("--hidden") << QString::number(hiddenSize)
         << QStringLiteral("--layers") << QString::number(numLayers)
         << QStringLiteral("--train-epochs") << QString::number(trainEpochs)
         << QStringLiteral("--lr") << QString::number(learningRate, 'g', 6)
         << QStringLiteral("--batch") << QString::number(batchSize);

    if (!gtStcPrefix.isEmpty()) {
        args << QStringLiteral("--gt-stc") << gtStcPrefix;
    }

    if (!finetuneOnnxPath.isEmpty()) {
        args << QStringLiteral("--finetune") << finetuneOnnxPath;
    }

    // Configure PythonRunner with venv + pyproject.toml
    // Venv lives inside the cmne package directory as .venv/
    config.pythonExe = pythonExe;
    config.venvDir = QDir(cmneDir).absoluteFilePath(QStringLiteral(".venv"));
    config.packageDir = cmneDir;

    MLLIB::MLTrainer trainer(config);

    // NOTE(review): synchronous vs asynchronous semantics of run() are
    // defined by MLTrainer — confirm before relying on blocking behavior.
    return trainer.run(scriptPath, args);
}
413
414#endif // !WASMBUILD
InvCMNE class declaration (Contextual MNE, Dinh et al. 2021).
MLTrainer class declaration — ML-specific convenience wrapper over PythonRunner.
MlOnnxModel class declaration.
MlTensor class declaration — N-dimensional, row-major, zero-copy.
Inverse source estimation (MNE, dSPM, sLORETA, dipole fitting).
CMNE result.
Definition inv_cmne.h:75
InvSourceEstimate stcDspm
Definition inv_cmne.h:76
InvSourceEstimate stcCmne
Definition inv_cmne.h:77
Eigen::MatrixXd matKernelDspm
Definition inv_cmne.h:79
InvSourceEstimate stcLstmPredict
Definition inv_cmne.h:78
static InvCMNEResult compute(const Eigen::MatrixXd &matEvoked, const Eigen::MatrixXd &matGain, const Eigen::MatrixXd &matNoiseCov, const Eigen::MatrixXd &matSrcCov, const InvCMNESettings &settings)
Definition inv_cmne.cpp:77
static Eigen::MatrixXd applyLstmCorrection(const Eigen::MatrixXd &matDspmData, const QString &onnxModelPath, int lookBack)
Definition inv_cmne.cpp:234
static UTILSLIB::PythonRunnerResult trainLstm(const QString &fwdPath, const QString &covPath, const QString &epochsPath, const QString &outOnnxPath, const InvCMNESettings &settings, const QString &gtStcPrefix={}, int hiddenSize=256, int numLayers=1, int trainEpochs=50, double learningRate=1e-3, int batchSize=64, const QString &finetuneOnnxPath={}, const QString &pythonExe=QStringLiteral("python3"))
Definition inv_cmne.cpp:329
ONNX Runtime backed model.
bool load(const QString &path) override
MlTensor predict(const MlTensor &input) const override
N-dimensional tensor with contiguous row-major (C-order) float32 storage.
Definition ml_tensor.h:79
static MlTensor view(float *data, std::vector< int64_t > shape)
ML training script launcher.
Definition ml_trainer.h:77
UTILSLIB::PythonRunnerResult run(const QString &scriptPath, const QStringList &args={})
Script execution result container.
Script execution configuration.