v2.0.0
Loading...
Searching...
No Matches
inv_mxne.cpp
Go to the documentation of this file.
1//=============================================================================================================
34
35//=============================================================================================================
36// INCLUDES
37//=============================================================================================================
38
39#include "inv_mxne.h"
40
41//=============================================================================================================
42// STL INCLUDES
43//=============================================================================================================
44
#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>
47
48//=============================================================================================================
49// USED NAMESPACES
50//=============================================================================================================
51
52using namespace INVLIB;
53using namespace Eigen;
54
55//=============================================================================================================
56// DEFINE MEMBER METHODS
57//=============================================================================================================
58
60 const MatrixXd& matGain,
61 const MatrixXd& matData,
62 double alpha,
63 int nIterations,
64 double tolerance)
65{
66 const int nChannels = static_cast<int>(matGain.rows());
67 const int nSources = static_cast<int>(matGain.cols());
68 const int nTimes = static_cast<int>(matData.cols());
69
70 // Precompute G^T * G and G^T * M
71 MatrixXd matGtG = matGain.transpose() * matGain;
72 MatrixXd matGtM = matGain.transpose() * matData;
73
74 // Initialize weights to 1
75 VectorXd vecWeights = VectorXd::Ones(nSources);
76 VectorXd vecWeightsOld = vecWeights;
77
78 // Active set: all sources initially active
79 std::vector<int> activeIdx(nSources);
80 std::iota(activeIdx.begin(), activeIdx.end(), 0);
81
82 // Full source solution
83 MatrixXd matX = MatrixXd::Zero(nSources, nTimes);
84
85 int actualIterations = 0;
86
87 for (int iter = 0; iter < nIterations; ++iter) {
88 actualIterations = iter + 1;
89
90 const int nActive = static_cast<int>(activeIdx.size());
91 if (nActive == 0)
92 break;
93
94 // Extract active columns of G^T*G and G^T*M
95 MatrixXd matGtG_active(nActive, nActive);
96 MatrixXd matGtM_active(nActive, nTimes);
97
98 for (int i = 0; i < nActive; ++i) {
99 matGtM_active.row(i) = matGtM.row(activeIdx[i]);
100 for (int j = 0; j < nActive; ++j) {
101 matGtG_active(i,j) = matGtG(activeIdx[i], activeIdx[j]);
102 }
103 }
104
105 // Build diagonal weight matrix W = diag(1/w_i^2)
106 VectorXd vecWdiag(nActive);
107 for (int i = 0; i < nActive; ++i) {
108 double w = vecWeights(activeIdx[i]);
109 vecWdiag(i) = 1.0 / (w * w);
110 }
111
112 // Solve (G^T*G + alpha*W) * X_active = G^T*M
113 MatrixXd matLhs = matGtG_active;
114 matLhs.diagonal() += alpha * vecWdiag;
115
116 MatrixXd matX_active = matLhs.ldlt().solve(matGtM_active);
117
118 // Write back to full solution
119 matX.setZero();
120 for (int i = 0; i < nActive; ++i) {
121 matX.row(activeIdx[i]) = matX_active.row(i);
122 }
123
124 // Update weights: w_i = max(||X_i||_2, 1e-10)
125 vecWeightsOld = vecWeights;
126 for (int i = 0; i < nSources; ++i) {
127 vecWeights(i) = std::max(matX.row(i).norm(), 1e-10);
128 }
129
130 // Active set pruning: keep sources with w_i >= 1e-8
131 std::vector<int> newActive;
132 newActive.reserve(nActive);
133 for (int i = 0; i < nSources; ++i) {
134 if (vecWeights(i) >= 1e-8) {
135 newActive.push_back(i);
136 }
137 }
138 activeIdx = newActive;
139
140 // Check convergence
141 double maxChange = 0.0;
142 for (int idx : activeIdx) {
143 maxChange = std::max(maxChange, std::abs(vecWeights(idx) - vecWeightsOld(idx)));
144 }
145 if (maxChange < tolerance)
146 break;
147 }
148
149 // Build result
150 InvMxneResult result;
151 result.nIterations = actualIterations;
152
153 // Collect active vertices and build sparse output
154 QVector<int> finalActive;
155 for (int i = 0; i < nSources; ++i) {
156 if (matX.row(i).norm() >= 1e-8) {
157 finalActive.append(i);
158 }
159 }
160 result.activeVertices = finalActive;
161
162 // Build source estimate with active rows only
163 const int nActiveFinal = finalActive.size();
164 MatrixXd matActiveSol(nActiveFinal, nTimes);
165 VectorXi vecActiveVerts(nActiveFinal);
166 for (int i = 0; i < nActiveFinal; ++i) {
167 matActiveSol.row(i) = matX.row(finalActive[i]);
168 vecActiveVerts(i) = finalActive[i];
169 }
170
171 result.stc = InvSourceEstimate(matActiveSol, vecActiveVerts, 0.0f, 1.0f);
173
174 // Compute residual norm ||M - G*X||_F
175 MatrixXd matResidual = matData - matGain * matX;
176 result.residualNorm = matResidual.norm();
177
178 return result;
179}
InvMxne class declaration.
Inverse source estimation (MNE, dSPM, sLORETA, dipole fitting).
QVector< int > activeVertices
Definition inv_mxne.h:71
InvSourceEstimate stc
Definition inv_mxne.h:70
static InvMxneResult compute(const Eigen::MatrixXd &matGain, const Eigen::MatrixXd &matData, double alpha, int nIterations=50, double tolerance=1e-6)
Definition inv_mxne.cpp:59