44 const MatrixXd& matData,
52 const int nChannels =
static_cast<int>(matData.rows());
53 const int nTimes =
static_cast<int>(matData.cols());
55 if (nComponents < 0) {
56 nComponents = nChannels;
62 VectorXd vecMean = matData.rowwise().mean();
63 MatrixXd matCentered = matData.colwise() - vecMean;
68 MatrixXd matCov = (matCentered * matCentered.transpose()) /
static_cast<double>(nTimes - 1);
70 SelfAdjointEigenSolver<MatrixXd> eigSolver(matCov);
71 VectorXd vecEigVals = eigSolver.eigenvalues();
72 MatrixXd matEigVecs = eigSolver.eigenvectors();
75 VectorXd vecTopEigVals = vecEigVals.tail(nComponents).reverse();
76 MatrixXd matTopEigVecs = matEigVecs.rightCols(nComponents).rowwise().reverse();
79 VectorXd vecInvSqrtEig = vecTopEigVals.array().sqrt().inverse();
80 MatrixXd matWhitening = vecInvSqrtEig.asDiagonal() * matTopEigVecs.transpose();
83 VectorXd vecSqrtEig = vecTopEigVals.array().sqrt();
84 MatrixXd matDewhitening = matTopEigVecs * vecSqrtEig.asDiagonal();
87 MatrixXd matWhite = matWhitening * matCentered;
92 MatrixXd matW = MatrixXd::Identity(nComponents, nComponents);
95 std::mt19937 gen(seed);
96 std::normal_distribution<double> dist(0.0, 0.01);
97 for (
int i = 0; i < nComponents; ++i) {
98 for (
int j = 0; j < nComponents; ++j) {
100 matW(i, j) = dist(gen);
113 const double dInvN = 1.0 /
static_cast<double>(nTimes);
114 MatrixXd matIdentity = MatrixXd::Identity(nComponents, nComponents);
120 double dCurrentLR = learningRate;
121 constexpr double dAnnealFactor = 0.998;
123 for (
int iter = 0; iter < maxIterations; ++iter) {
125 MatrixXd matSources = matW * matWhite;
128 VectorXd vecSigns = VectorXd::Ones(nComponents);
130 vecSigns = estimateSignVector(matSources);
134 MatrixXd matY(nComponents, nTimes);
135 for (
int i = 0; i < nComponents; ++i) {
136 if (vecSigns(i) > 0) {
138 matY.row(i) = -matSources.row(i).array().tanh();
141 matY.row(i) = matSources.row(i).array().tanh() - matSources.row(i).array();
146 MatrixXd matGrad = matIdentity + (matY * matSources.transpose()) * dInvN;
147 MatrixXd matDW = dCurrentLR * matGrad * matW;
155 double dChange = matDW.squaredNorm();
156 if (dChange < tolerance) {
161 dCurrentLR *= dAnnealFactor;
/**
 * @brief Estimates independent components of a multichannel signal with an
 *        Infomax-style ICA (PCA whitening followed by iterative gradient
 *        updates of a square unmixing matrix).
 *
 * Pipeline, as seen in the implementation body:
 *  1. Each row of @p matData is mean-centred.
 *  2. The channel covariance (normalised by nTimes - 1) is eigen-decomposed.
 *     The eigenvectors of the largest @p nComponents eigenvalues are kept and
 *     used to build a whitening matrix and its inverse (dewhitening).
 *  3. An nComponents x nComponents unmixing matrix is initialised from
 *     N(0, 0.01) draws of a std::mt19937 seeded with @p seed.
 *  4. The matrix is updated iteratively, with an annealed learning rate,
 *     until convergence or @p maxIterations.
 *
 * @param matData       Input data. Rows are treated as channels and columns as
 *                      time samples.
 * @param nComponents   Number of components to retain. A negative value (the
 *                      default) means "use all channels".
 *                      NOTE(review): no visible check that
 *                      0 < nComponents <= rows(matData); confirm it is
 *                      validated in the elided lines.
 * @param maxIterations Upper bound on the number of update iterations.
 * @param learningRate  Initial step size. It is multiplied by 0.998 on each
 *                      iteration.
 * @param tolerance     Convergence threshold, compared against the squared
 *                      Frobenius norm of the weight update. Presumably the
 *                      loop stops when the norm falls below it; the branch
 *                      body is not visible, so verify.
 * @param extendedMode  Presumably enables per-component sign estimation
 *                      (estimateSignVector), which selects between the
 *                      -tanh(u) and tanh(u) - u nonlinearities. The guarding
 *                      condition is not visible; confirm.
 *                      NOTE(review): the standard extended-Infomax rule
 *                      (Lee et al., 1999) uses -tanh(u) - u for super-Gaussian
 *                      sources. The positive-sign branch here uses only
 *                      -tanh(u); verify this is intentional.
 * @param seed          Seed for the RNG that initialises the unmixing matrix.
 *                      The same seed and the same input should reproduce the
 *                      same result.
 * @return InfomaxResult Fields are not visible here. Presumably they hold the
 *                      unmixing/mixing matrices and/or sources; check the
 *                      struct definition.
 */
static InfomaxResult compute(const Eigen::MatrixXd &matData, int nComponents=-1, int maxIterations=200, double learningRate=0.001, double tolerance=1e-7, bool extendedMode=true, unsigned int seed=0)