v2.0.0
Loading...
Searching...
No Matches
extended_infomax.cpp
Go to the documentation of this file.
1//=============================================================================================================
34
35//=============================================================================================================
36// INCLUDES
37//=============================================================================================================
38
39#include "extended_infomax.h"
40
41//=============================================================================================================
42// EIGEN INCLUDES
43//=============================================================================================================
44
45#include <Eigen/Eigenvalues>
46
47//=============================================================================================================
48// STL INCLUDES
49//=============================================================================================================
50
#include <algorithm>
#include <cmath>
#include <limits>
#include <random>
53
54//=============================================================================================================
55// USED NAMESPACES
56//=============================================================================================================
57
58using namespace UTILSLIB;
59using namespace Eigen;
60
61//=============================================================================================================
62// DEFINE MEMBER METHODS
63//=============================================================================================================
64
66 const MatrixXd& matData,
67 int nComponents,
68 int maxIterations,
69 double learningRate,
70 double tolerance,
71 bool extendedMode,
72 unsigned int seed)
73{
74 const int nChannels = static_cast<int>(matData.rows());
75 const int nTimes = static_cast<int>(matData.cols());
76
77 if (nComponents < 0) {
78 nComponents = nChannels;
79 }
80
81 //=========================================================================================================
82 // Step 1: Mean removal
83 //=========================================================================================================
84 VectorXd vecMean = matData.rowwise().mean();
85 MatrixXd matCentered = matData.colwise() - vecMean;
86
87 //=========================================================================================================
88 // Step 2: PCA whitening
89 //=========================================================================================================
90 MatrixXd matCov = (matCentered * matCentered.transpose()) / static_cast<double>(nTimes - 1);
91
92 SelfAdjointEigenSolver<MatrixXd> eigSolver(matCov);
93 VectorXd vecEigVals = eigSolver.eigenvalues();
94 MatrixXd matEigVecs = eigSolver.eigenvectors();
95
96 // Eigen returns eigenvalues in ascending order; take the top nComponents (largest)
97 VectorXd vecTopEigVals = vecEigVals.tail(nComponents).reverse();
98 MatrixXd matTopEigVecs = matEigVecs.rightCols(nComponents).rowwise().reverse();
99
100 // Whitening: P = D^{-1/2} * V^T
101 VectorXd vecInvSqrtEig = vecTopEigVals.array().sqrt().inverse();
102 MatrixXd matWhitening = vecInvSqrtEig.asDiagonal() * matTopEigVecs.transpose();
103
104 // Dewhitening: P_inv = V * D^{1/2}
105 VectorXd vecSqrtEig = vecTopEigVals.array().sqrt();
106 MatrixXd matDewhitening = matTopEigVecs * vecSqrtEig.asDiagonal();
107
108 // Whitened data
109 MatrixXd matWhite = matWhitening * matCentered;
110
111 //=========================================================================================================
112 // Step 3: Initialize weights
113 //=========================================================================================================
114 MatrixXd matW = MatrixXd::Identity(nComponents, nComponents);
115
116 if (seed != 0) {
117 std::mt19937 gen(seed);
118 std::normal_distribution<double> dist(0.0, 0.01);
119 for (int i = 0; i < nComponents; ++i) {
120 for (int j = 0; j < nComponents; ++j) {
121 if (i != j) {
122 matW(i, j) = dist(gen);
123 }
124 }
125 }
126 }
127
128 //=========================================================================================================
129 // Step 4: Main iteration loop
130 //=========================================================================================================
131 InfomaxResult result;
132 result.converged = false;
133 result.nIterations = 0;
134
135 const double dInvN = 1.0 / static_cast<double>(nTimes);
136 MatrixXd matIdentity = MatrixXd::Identity(nComponents, nComponents);
137
138 // Learning rate annealing: the natural gradient has a nonzero steady-state
139 // when the assumed nonlinearity does not perfectly match the true source
140 // distribution. Geometric decay lets the step shrink toward zero so the
141 // convergence criterion can fire.
142 double dCurrentLR = learningRate;
143 constexpr double dAnnealFactor = 0.998;
144
145 for (int iter = 0; iter < maxIterations; ++iter) {
146 // Compute sources
147 MatrixXd matSources = matW * matWhite;
148
149 // Estimate sign vector for extended mode
150 VectorXd vecSigns = VectorXd::Ones(nComponents);
151 if (extendedMode) {
152 vecSigns = estimateSignVector(matSources);
153 }
154
155 // Compute nonlinearity
156 MatrixXd matY(nComponents, nTimes);
157 for (int i = 0; i < nComponents; ++i) {
158 if (vecSigns(i) > 0) {
159 // Super-Gaussian: g(u) = -tanh(u)
160 matY.row(i) = -matSources.row(i).array().tanh();
161 } else {
162 // Sub-Gaussian: g(u) = tanh(u) - u
163 matY.row(i) = matSources.row(i).array().tanh() - matSources.row(i).array();
164 }
165 }
166
167 // Natural gradient: dW = lr * (I + Y * S^T / n_times) * W
168 MatrixXd matGrad = matIdentity + (matY * matSources.transpose()) * dInvN;
169 MatrixXd matDW = dCurrentLR * matGrad * matW;
170
171 matW += matDW;
172 result.nIterations = iter + 1;
173
174 // Convergence: squared Frobenius norm of the weight update
175 // (matches MNE-Python's criterion). With learning rate annealing
176 // the update shrinks each iteration.
177 double dChange = matDW.squaredNorm();
178 if (dChange < tolerance) {
179 result.converged = true;
180 break;
181 }
182
183 dCurrentLR *= dAnnealFactor;
184 }
185
186 //=========================================================================================================
187 // Step 5: Compute output matrices
188 //=========================================================================================================
189 // Unmixing in original sensor space: W_total = W * P
190 result.matUnmixing = matW * matWhitening;
191
192 // Mixing matrix: pseudo-inverse of unmixing
193 result.matMixing = result.matUnmixing.completeOrthogonalDecomposition().pseudoInverse();
194
195 // Source activations
196 result.matSources = result.matUnmixing * matCentered;
197
198 return result;
199}
200
201//=============================================================================================================
202
203VectorXd ExtendedInfomax::estimateSignVector(const MatrixXd& matSources)
204{
205 const int nComponents = static_cast<int>(matSources.rows());
206 const int nTimes = static_cast<int>(matSources.cols());
207 const double dInvN = 1.0 / static_cast<double>(nTimes);
208
209 VectorXd vecSigns(nComponents);
210
211 for (int i = 0; i < nComponents; ++i) {
212 double dMean = matSources.row(i).mean();
213 ArrayXd arrCentered = matSources.row(i).array() - dMean;
214 double dM2 = (arrCentered.square()).sum() * dInvN;
215 double dM4 = (arrCentered.square().square()).sum() * dInvN;
216
217 // Excess kurtosis: m4/m2^2 - 3
218 double dKurtosis = (dM2 > 0.0) ? (dM4 / (dM2 * dM2)) - 3.0 : 0.0;
219
220 vecSigns(i) = (dKurtosis > 0.0) ? 1.0 : -1.0;
221 }
222
223 return vecSigns;
224}
ExtendedInfomax class declaration.
Shared utilities (I/O helpers, spectral analysis, layout management, warp algorithms).
Eigen::MatrixXd matUnmixing
static InfomaxResult compute(const Eigen::MatrixXd &matData, int nComponents=-1, int maxIterations=200, double learningRate=0.001, double tolerance=1e-7, bool extendedMode=true, unsigned int seed=0)