InvCMNE
Namespace: INVLIB · Library: Inverse Library
#include <inv/inv_cmne.h>
class INVLIB::InvCMNE
Contextual Minimum Norm Estimate (CMNE) inverse solver.
Implements the algorithm from: Dinh et al. "Contextual Minimum-Norm Estimates (CMNE): A Deep Learning Method for Source Estimation in Neuroimaging", 2021.
CMNE inverse solver
Static Methods
compute(matEvoked, matGain, matNoiseCov, matSrcCov, settings)
Compute CMNE inverse solution.
Parameters:
- matEvoked (const Eigen::MatrixXd &): Evoked data (n_channels x n_times).
- matGain (const Eigen::MatrixXd &): Forward gain matrix (n_channels x n_sources).
- matNoiseCov (const Eigen::MatrixXd &): Noise covariance (n_channels x n_channels).
- matSrcCov (const Eigen::MatrixXd &): Source covariance (n_sources x n_sources, diagonal).
- settings (const InvCMNESettings &): CMNE settings.
Returns:
- InvCMNEResult — InvCMNEResult containing the dSPM and CMNE source estimates.
applyLstmCorrection(matDspmData, onnxModelPath, lookBack)
Apply LSTM-based temporal correction to z-scored rectified dSPM data.
Parameters:
- matDspmData (const Eigen::MatrixXd &): Z-scored, rectified dSPM data (n_sources x n_times).
- onnxModelPath (const QString &): Path to the ONNX model file.
- lookBack (int): Number of past time steps (k).
Returns:
- Eigen::MatrixXd — Corrected source data (n_sources x n_times).
trainLstm(fwdPath, covPath, epochsPath, outOnnxPath, settings, gtStcPrefix, hiddenSize, numLayers, trainEpochs, learningRate, batchSize, finetuneOnnxPath, pythonExe)
Train the CMNE LSTM model by invoking the Python training script.
This is a convenience wrapper that calls scripts/ml/training/train_cmne_lstm.py via PythonRunner. The heavy lifting (PyTorch LSTM training + ONNX export) happens in Python; C++ only launches the process and streams its output.
Parameters:
- fwdPath (const QString &): Path to the forward solution FIFF file.
- covPath (const QString &): Path to the noise covariance FIFF file.
- epochsPath (const QString &): Path to the epochs FIFF file.
- outOnnxPath (const QString &): Desired output path for the ONNX model.
- settings (const InvCMNESettings &): CMNE settings (look-back, method and SNR are forwarded).
- gtStcPrefix (const QString &): Ground-truth STC prefix (optional; empty = simulation mode).
- hiddenSize (int): LSTM hidden dimension (default 256).
- numLayers (int): Number of LSTM layers (default 1).
- trainEpochs (int): Number of training epochs (default 50).
- learningRate (double): Learning rate (default 1e-3).
- batchSize (int): Batch size (default 64).
- finetuneOnnxPath (const QString &): Existing ONNX model to fine-tune from (optional).
- pythonExe (const QString &): Python interpreter (default "python3").
Returns:
- UTILSLIB::PythonRunnerResult — PythonRunnerResult with exit code, captured output and progress.
Example
Source: src/examples/ex_clustered_inverse_mne/main.cpp (a clustered MNE/dSPM/sLORETA inverse example; it does not call InvCMNE directly)
#include <fs/fs_label.h>
#include <fs/fs_surface.h>
#include <fs/fs_surfaceset.h>
#include <fs/fs_annotationset.h>
#include <fiff/fiff_evoked.h>
#include <inv/inv_source_estimate.h>
#include <inv/minimum_norm/inv_minimum_norm.h>
#include <disp3D/view/brainview.h>
#include <disp3D/model/braintreemodel.h>
#include <math/linalg.h>
#include <utils/generics/mne_logger.h>
#include <algorithm>
#include <iostream>
//=============================================================================================================
// QT INCLUDES
//=============================================================================================================
#include <QApplication>
#include <QCommandLineParser>
#include <QDir>
#include <QFile>
#include <QSet>
#include <QVector3D>
//=============================================================================================================
// USED NAMESPACES
//=============================================================================================================
using namespace MNELIB;
using namespace FSLIB;
using namespace FIFFLIB;
using namespace INVLIB;
using namespace UTILSLIB;
using namespace Eigen;
//=============================================================================================================
// MAIN
//=============================================================================================================
//=============================================================================================================
/**
* The function main marks the entry point of the program.
* By default, main has the storage class extern.
*
* @param[in] argc (argument count) is an integer that indicates how many arguments were entered on the command line when the program was started.
* @param[in] argv (argument vector) is an array of pointers to arrays of character objects. The array objects are null-terminated strings, representing the arguments that were entered on the command line when the program was started.
* @return the value that was set to exit() (which is 0 if exit() is called via quit()).
*/
int main(int argc, char *argv[])
{
#ifdef STATICBUILD
// Q_INIT_RESOURCE(mne_disp3d);
#endif
qInstallMessageHandler(MNELogger::customLogWriter);
QApplication app(argc, argv);
// Command Line Parser
QCommandLineParser parser;
parser.setApplicationDescription("Clustered Inverse MNE Example");
parser.addHelpOption();
QCommandLineOption sampleFwdFileOption("fwd", "Path to the forward solution <file>.", "file", QCoreApplication::applicationDirPath() + "/../resources/data/MNE-sample-data/MEG/sample/sample_audvis-meg-eeg-oct-6-fwd.fif");
QCommandLineOption sampleCovFileOption("cov", "Path to the covariance <file>.", "file", QCoreApplication::applicationDirPath() + "/../resources/data/MNE-sample-data/MEG/sample/sample_audvis-cov.fif");
QCommandLineOption sampleEvokedFileOption("ave", "Path to the evoked/average <file>.", "file", QCoreApplication::applicationDirPath() + "/../resources/data/MNE-sample-data/MEG/sample/sample_audvis-ave.fif");
QCommandLineOption snrOption("snr", "The <snr> value used for computation.", "snr", "1.0");//3.0;//0.1;//3.0;
QCommandLineOption methodOption("method", "Inverse estimation <method>, i.e., 'MNE', 'dSPM' or 'sLORETA'.", "method", "dSPM");//"MNE" | "dSPM" | "sLORETA"
QCommandLineOption invFileOption("invOut", "Path to inverse <file>, which is to be written.", "file", "");
QCommandLineOption stcFileOption("stcOut", "Path to stc <file>, which is to be written.", "file", "");
QCommandLineOption surfOption("surfType", "FsSurface type <type>.", "type", "orig");
QCommandLineOption annotOption("annotType", "FsAnnotation type <type>.", "type", "aparc.a2009s");
QCommandLineOption hemiOption("hemi", "Selected hemisphere <hemi>.", "hemi", "2");
QCommandLineOption subjectOption("subject", "Selected subject <subject>.", "subject", "sample");
QCommandLineOption subjectPathOption("subjectPath", "Selected subject path <subjectPath>.", "subjectPath", QCoreApplication::applicationDirPath() + "/../resources/data/MNE-sample-data/subjects");
parser.addOption(sampleFwdFileOption);
parser.addOption(sampleCovFileOption);
parser.addOption(sampleEvokedFileOption);
parser.addOption(snrOption);
parser.addOption(methodOption);
parser.addOption(invFileOption);
parser.addOption(stcFileOption);
parser.addOption(surfOption);
parser.addOption(annotOption);
parser.addOption(hemiOption);
parser.addOption(subjectOption);
parser.addOption(subjectPathOption);
parser.process(app);
//Parse inputs
QFile t_fileFwd(parser.value(sampleFwdFileOption));
QFile t_fileCov(parser.value(sampleCovFileOption));
QFile t_fileEvoked(parser.value(sampleEvokedFileOption));
double snr = parser.value(snrOption).toDouble();
QString method = parser.value(methodOption);
QString t_sFileNameClusteredInv = parser.value(invFileOption);
QString t_sFileNameStc = parser.value(stcFileOption);
double lambda2 = 1.0 / pow(snr, 2);
FsSurfaceSet t_surfSet (parser.value(subjectOption), parser.value(hemiOption).toInt(), parser.value(surfOption), parser.value(subjectPathOption));
FsAnnotationSet t_annotationSet (parser.value(subjectOption), parser.value(hemiOption).toInt(), parser.value(annotOption), parser.value(subjectPathOption));
qDebug() << "Start calculation with: SNR" << snr << "; Lambda" << lambda2 << "; Method" << method << "; stc:" << t_sFileNameStc;
// Load data
fiff_int_t setno = 0;
QPair<float, float> baseline(-1.0f, -1.0f);
FiffEvoked evoked(t_fileEvoked, setno, baseline);
if(evoked.isEmpty())
return 1;
std::cout << "evoked first " << evoked.first << "; last " << evoked.last << std::endl;
MNEForwardSolution t_Fwd(t_fileFwd);
if(t_Fwd.isEmpty())
return 1;
FiffCov noise_cov(t_fileCov);
// regularize noise covariance
noise_cov = noise_cov.regularize(evoked.info, 0.05, 0.05, 0.1, true);
//
// Cluster forward solution;
//
MNEForwardSolution t_clusteredFwd = t_Fwd.cluster_forward_solution(t_annotationSet, 20);//40);
// std::cout << "Size " << t_clusteredFwd.sol->data.rows() << " x " << t_clusteredFwd.sol->data.cols() << std::endl;
// std::cout << "Clustered Fwd:\n" << t_clusteredFwd.sol->data.row(0) << std::endl;
//
// make an inverse operators
//
FiffInfo info = evoked.info;
MNEInverseOperator inverse_operator(info, t_clusteredFwd, noise_cov, 0.2f, 0.8f);
//
// save clustered inverse
//
if(!t_sFileNameClusteredInv.isEmpty())
{
QFile t_fileClusteredInverse(t_sFileNameClusteredInv);
inverse_operator.write(t_fileClusteredInverse);
}
//
// Compute inverse solution
//
InvMinimumNorm minimumNorm(inverse_operator, lambda2, method);
InvSourceEstimate sourceEstimate = minimumNorm.calculateInverse(evoked);
if(sourceEstimate.isEmpty())
return 1;
// View activation time-series
std::cout << "\nsourceEstimate:\n" << sourceEstimate.data.block(0,0,10,10) << std::endl;
std::cout << "time\n" << sourceEstimate.times.block(0,0,1,10) << std::endl;
std::cout << "timeMin\n" << sourceEstimate.times[0] << std::endl;
std::cout << "timeMax\n" << sourceEstimate.times[sourceEstimate.times.size()-1] << std::endl;
std::cout << "time step\n" << sourceEstimate.tstep << std::endl;
//Condition Numbers
// MatrixXd mags(102, t_Fwd.sol->data.cols());
// qint32 count = 0;
// for(qint32 i = 2; i < 306; i += 3)
// {
// mags.row(count) = t_Fwd.sol->data.row(i);
// ++count;
// }
// MatrixXd magsClustered(102, t_clusteredFwd.sol->data.cols());
// count = 0;
// for(qint32 i = 2; i < 306; i += 3)
// {
// magsClustered.row(count) = t_clusteredFwd.sol->data.row(i);
// ++count;
// }
// MatrixXd grads(204, t_Fwd.sol->data.cols());
// count = 0;
// for(qint32 i = 0; i < 306; i += 3)
// {
// grads.row(count) = t_Fwd.sol->data.row(i);
// ++count;
// grads.row(count) = t_Fwd.sol->data.row(i+1);
// ++count;
// }
// MatrixXd gradsClustered(204, t_clusteredFwd.sol->data.cols());
// count = 0;
// for(qint32 i = 0; i < 306; i += 3)
// {
// gradsClustered.row(count) = t_clusteredFwd.sol->data.row(i);
// ++count;
// gradsClustered.row(count) = t_clusteredFwd.sol->data.row(i+1);
// ++count;
// }
VectorXd s;
double t_dConditionNumber = Linalg::getConditionNumber(t_Fwd.sol->data, s);
double t_dConditionNumberClustered = Linalg::getConditionNumber(t_clusteredFwd.sol->data, s);
std::cout << "Condition Number:\n" << t_dConditionNumber << std::endl;
std::cout << "Clustered Condition Number:\n" << t_dConditionNumberClustered << std::endl;
std::cout << "ForwardSolution" << t_Fwd.sol->data.block(0,0,10,10) << std::endl;
std::cout << "Clustered ForwardSolution" << t_clusteredFwd.sol->data.block(0,0,10,10) << std::endl;
// double t_dConditionNumberMags = Linalg::getConditionNumber(mags, s);
// double t_dConditionNumberMagsClustered = Linalg::getConditionNumber(magsClustered, s);
// std::cout << "Condition Number Magnetometers:\n" << t_dConditionNumberMags << std::endl;
// std::cout << "Clustered Condition Number Magnetometers:\n" << t_dConditionNumberMagsClustered << std::endl;
// double t_dConditionNumberGrads = Linalg::getConditionNumber(grads, s);
// double t_dConditionNumberGradsClustered = Linalg::getConditionNumber(gradsClustered, s);
// std::cout << "Condition Number Gradiometers:\n" << t_dConditionNumberGrads << std::endl;
// std::cout << "Clustered Condition Number Gradiometers:\n" << t_dConditionNumberGradsClustered << std::endl;
//Source Estimate end
//########################################################################################
// //only one time point - P100
// qint32 sample = 0;
// for(qint32 i = 0; i < sourceEstimate.times.size(); ++i)
// {
// if(sourceEstimate.times(i) >= 0)
// {
// sample = i;
// break;
// }
// }
// sample += (qint32)ceil(0.106/sourceEstimate.tstep); //100ms
// sourceEstimate = sourceEstimate.reduce(sample, 1);
// Write source estimate to temp files for visualization
int nVertLh = t_clusteredFwd.src[0].nuse;
InvSourceEstimate stcLh, stcRh;
stcLh.data = sourceEstimate.data.topRows(nVertLh);
stcLh.vertices = sourceEstimate.vertices.head(nVertLh);
stcLh.tmin = sourceEstimate.tmin;
stcLh.tstep = sourceEstimate.tstep;
stcLh.times = sourceEstimate.times;
stcRh.data = sourceEstimate.data.bottomRows(sourceEstimate.data.rows() - nVertLh);
stcRh.vertices = sourceEstimate.vertices.tail(sourceEstimate.vertices.size() - nVertLh);
stcRh.tmin = sourceEstimate.tmin;
stcRh.tstep = sourceEstimate.tstep;
stcRh.times = sourceEstimate.times;
QString tmpDir = QDir::tempPath();
QString lhStcPath = tmpDir + "/mnecpp_clustered_inv_mne-lh.stc";
QString rhStcPath = tmpDir + "/mnecpp_clustered_inv_mne-rh.stc";
QFile lhStcFile(lhStcPath);
stcLh.write(lhStcFile);
QFile rhStcFile(rhStcPath);
stcRh.write(rhStcFile);
BrainView *pBrainView = new BrainView();
BrainTreeModel *pModel = new BrainTreeModel();
pBrainView->setModel(pModel);
// Add hemisphere surfaces
for (auto it = t_surfSet.data().constBegin(); it != t_surfSet.data().constEnd(); ++it) {
int hIdx = it.key();
QString hemi = (it.value().hemi() == 0) ? "lh" : "rh";
QString surfType = it.value().surf().isEmpty() ? "inflated" : it.value().surf();
pModel->addSurface(parser.value(subjectOption), hemi, surfType, it.value());
if (t_annotationSet.size() > hIdx)
pModel->addAnnotation(parser.value(subjectOption), hemi, t_annotationSet[hIdx]);
}
// Load source estimate and start streaming when loaded
QObject::connect(pBrainView, &BrainView::sourceEstimateLoaded, [&](int /*nTimePoints*/) {
pBrainView->setSourceColormap("Hot");
pBrainView->setSourceThresholds(0.0f, 0.5f, 10.0f);
pBrainView->setRealtimeLooping(true);
pBrainView->setRealtimeInterval(17);
pBrainView->startRealtimeStreaming();
});
pBrainView->loadSourceEstimate(lhStcPath, rhStcPath);
if(!t_sFileNameStc.isEmpty())
{
QFile t_fileClusteredStc(t_sFileNameStc);
sourceEstimate.write(t_fileClusteredStc);
}
pBrainView->show();
return app.exec();//1;
}
Authors of this file
- Christoph Dinh <christoph.dinh@mne-cpp.org>