MNE-CPP  0.1.9
A Framework for Electrophysiology
crosscorrelation.cpp
Go to the documentation of this file.
1 //=============================================================================================================
35 //=============================================================================================================
36 // INCLUDES
37 //=============================================================================================================
38 
39 #include "crosscorrelation.h"
40 #include "../network/networknode.h"
41 #include "../network/networkedge.h"
42 #include "../network/network.h"
43 
44 #include <utils/spectral.h>
45 
46 //=============================================================================================================
47 // QT INCLUDES
48 //=============================================================================================================
49 
50 #include <QDebug>
51 #include <QtConcurrent>
52 
53 //=============================================================================================================
54 // EIGEN INCLUDES
55 //=============================================================================================================
56 
57 #include <unsupported/Eigen/FFT>
58 
59 //=============================================================================================================
60 // USED NAMESPACES
61 //=============================================================================================================
62 
63 using namespace CONNECTIVITYLIB;
64 using namespace Eigen;
65 using namespace UTILSLIB;
66 
67 //=============================================================================================================
68 // DEFINE GLOBAL METHODS
69 //=============================================================================================================
70 
71 //=============================================================================================================
72 // DEFINE MEMBER METHODS
73 //=============================================================================================================
74 
76 {
77 }
78 
79 //=============================================================================================================
80 
82 {
83 // QElapsedTimer timer;
84 // qint64 iTime = 0;
85 // timer.start();
86 
87  #ifdef EIGEN_FFTW_DEFAULT
88  fftw_make_planner_thread_safe();
89  #endif
90 
91  Network finalNetwork("XCOR");
92 
93  if(connectivitySettings.isEmpty()) {
94  qDebug() << "CrossCorrelation::calculate - Input data is empty";
95  return finalNetwork;
96  }
97 
98  if(AbstractMetric::m_bStorageModeIsActive == false) {
99  connectivitySettings.clearIntermediateData();
100  }
101 
102  finalNetwork.setSamplingFrequency(connectivitySettings.getSamplingFrequency());
103 
104  //Create nodes
105  int rows = connectivitySettings.at(0).matData.rows();
106  RowVectorXf rowVert = RowVectorXf::Zero(3);
107 
108  for(int i = 0; i < rows; ++i) {
109  rowVert = RowVectorXf::Zero(3);
110 
111  if(connectivitySettings.getNodePositions().rows() != 0 && i < connectivitySettings.getNodePositions().rows()) {
112  rowVert(0) = connectivitySettings.getNodePositions().row(i)(0);
113  rowVert(1) = connectivitySettings.getNodePositions().row(i)(1);
114  rowVert(2) = connectivitySettings.getNodePositions().row(i)(2);
115  }
116 
117  finalNetwork.append(NetworkNode::SPtr(new NetworkNode(i, rowVert)));
118  }
119 
120  // Generate tapers
121  int iSignalLength = connectivitySettings.at(0).matData.cols();
122  int iNfft = connectivitySettings.getFFTSize();
123 
124  QPair<MatrixXd, VectorXd> tapers = Spectral::generateTapers(iSignalLength, connectivitySettings.getWindowType());
125 
126  // Compute the cross correlation in parallel
127  QMutex mutex;
128  MatrixXd matDist;
129 
130  std::function<void(ConnectivitySettings::IntermediateTrialData&)> computeLambda = [&](ConnectivitySettings::IntermediateTrialData& inputData) {
131  compute(inputData,
132  matDist,
133  mutex,
134  iNfft,
135  tapers);
136  };
137 
138 // iTime = timer.elapsed();
139 // qWarning() << "Preparation" << iTime;
140 // timer.restart();
141 
142  // Calculate connectivity matrix over epochs and average afterwards
143  QFuture<void> resultMat = QtConcurrent::map(connectivitySettings.getTrialData(),
144  computeLambda);
145  resultMat.waitForFinished();
146 
147  matDist /= connectivitySettings.size();
148 
149 // iTime = timer.elapsed();
150 // qWarning() << "ComputeSpectraPSDCSD" << iTime;
151 // timer.restart();
152 
153  //Add edges to network
154  MatrixXd matWeight(1,1);
155  QSharedPointer<NetworkEdge> pEdge;
156  int j;
157 
158  for(int i = 0; i < matDist.rows(); ++i) {
159  for(j = i; j < matDist.cols(); ++j) {
160  matWeight << matDist(i,j);
161 
162  pEdge = QSharedPointer<NetworkEdge>(new NetworkEdge(i, j, matWeight));
163 
164  finalNetwork.getNodeAt(i)->append(pEdge);
165  finalNetwork.getNodeAt(j)->append(pEdge);
166  finalNetwork.append(pEdge);
167  }
168  }
169 
170 // iTime = timer.elapsed();
171 // qWarning() << "Compute" << iTime;
172 // timer.restart();
173 
174  return finalNetwork;
175 }
176 
177 //=============================================================================================================
178 
180  MatrixXd& matDist,
181  QMutex& mutex,
182  int iNfft,
183  const QPair<MatrixXd, VectorXd>& tapers)
184 {
185 // QElapsedTimer timer;
186 // qint64 iTime = 0;
187 // timer.start();
188 
189  // Calculate tapered spectra if not available already
190  RowVectorXd vecInputFFT, rowData;
191  RowVectorXcd vecResultFreq;
192 
193  FFT<double> fft;
194  fft.SetFlag(fft.HalfSpectrum);
195 
196  int i, j;
197  int iNRows = inputData.matData.rows();
198 
199  // Calculate tapered spectra if not available already
200  // This code was copied and changed modified Utils/Spectra since we do not want to call the function due to time loss.
201  if(inputData.vecTapSpectra.isEmpty()) {
202  int iNFreqs = int(floor(iNfft / 2.0)) + 1;
203  MatrixXcd matTapSpectrum(tapers.first.rows(), iNFreqs);
204 
205  for (i = 0; i < iNRows; ++i) {
206  // Substract mean
207  rowData.array() = inputData.matData.row(i).array() - inputData.matData.row(i).mean();
208 
209  // Calculate tapered spectra
210  for(j = 0; j < tapers.first.rows(); j++) {
211  // Zero padd if necessary. The zero padding in Eigen's FFT is only working for column vectors.
212  if (rowData.cols() < iNfft) {
213  vecInputFFT.setZero(iNfft);
214  vecInputFFT.block(0,0,1,rowData.cols()) = rowData.cwiseProduct(tapers.first.row(j));;
215  } else {
216  vecInputFFT = rowData.cwiseProduct(tapers.first.row(j));
217  }
218 
219  // FFT for freq domain returning the half spectrum and multiply taper weights
220  fft.fwd(vecResultFreq, vecInputFFT, iNfft);
221  matTapSpectrum.row(j) = vecResultFreq * tapers.second(j);
222  }
223 
224  inputData.vecTapSpectra.append(matTapSpectrum);
225  }
226  }
227 
228 // iTime = timer.elapsed();
229 // qDebug() << QThread::currentThreadId() << "CrossCorrelation::compute timer - Tapered spectra:" << iTime;
230 // timer.restart();
231 
232  // Perform multiplication and transform back to time domain to find max XCOR coefficient
233  // Note that the result in time domain is mirrored around the center of the data (compared to Matlab)
234  MatrixXd matDistTrial = MatrixXd::Zero(iNRows, iNRows);
235  RowVectorXcd vecResultXCor;
236  int idx = 0;
237  double denom = tapers.second.sum();
238 
239  for(i = 0; i < inputData.vecTapSpectra.size(); ++i) {
240  vecResultFreq = inputData.vecTapSpectra.at(i).colwise().sum() / denom;
241 
242  for(j = i; j < inputData.vecTapSpectra.size(); ++j) {
243  vecResultXCor = vecResultFreq.cwiseProduct(inputData.vecTapSpectra.at(j).colwise().sum() / denom);
244 
245  fft.inv(vecInputFFT, vecResultXCor, iNfft);
246 
247  vecInputFFT.maxCoeff(&idx);
248 
249  matDistTrial(i,j) = vecInputFFT(idx);
250  }
251  }
252 
253 // iTime = timer.elapsed();
254 // qDebug() << QThread::currentThreadId() << "CrossCorrelation::compute timer - Multiplication and inv FFT:" << iTime;
255 // timer.restart();
256 
257  // Sum up weights
258  mutex.lock();
259 
260  if(matDist.rows() != matDistTrial.rows() || matDist.cols() != matDistTrial.cols()) {
261  matDist.resize(matDistTrial.rows(), matDistTrial.cols());
262  matDist.setZero();
263  }
264 
265  matDist += matDistTrial;
266 
267  mutex.unlock();
268 
269 // iTime = timer.elapsed();
270 // qDebug() << QThread::currentThreadId() << "CrossCorrelation::compute timer - Summing up matDist:" << iTime;
271 // timer.restart();
272 
273  if(!m_bStorageModeIsActive) {
274  inputData.vecTapSpectra.clear();
275  }
276 }
CONNECTIVITYLIB::ConnectivitySettings::IntermediateTrialData
Definition: connectivitysettings.h:98
spectral.h
Declaration of Spectral class.
CONNECTIVITYLIB::Network
This class holds information about a network, can compute a distance table and provide network metric...
Definition: network.h:88
crosscorrelation.h
CrossCorrelation class declaration.
CONNECTIVITYLIB::Network::append
void append(QSharedPointer< NetworkEdge > newEdge)
CONNECTIVITYLIB::Network::getNodeAt
QSharedPointer< NetworkNode > getNodeAt(int i)
Definition: network.cpp:163
CONNECTIVITYLIB::CrossCorrelation::compute
static void compute(ConnectivitySettings::IntermediateTrialData &inputData, Eigen::MatrixXd &matDist, QMutex &mutex, int iNfft, const QPair< Eigen::MatrixXd, Eigen::VectorXd > &tapers)
Definition: crosscorrelation.cpp:179
CONNECTIVITYLIB::CrossCorrelation::calculate
static Network calculate(ConnectivitySettings &connectivitySettings)
Definition: crosscorrelation.cpp:81
CONNECTIVITYLIB::Network::setSamplingFrequency
void setSamplingFrequency(float fSFreq)
Definition: network.cpp:492
CONNECTIVITYLIB::CrossCorrelation::CrossCorrelation
CrossCorrelation()
Definition: crosscorrelation.cpp:75
CONNECTIVITYLIB::NetworkEdge
This class holds an object to describe the edge of a network.
Definition: networkedge.h:79
CONNECTIVITYLIB::ConnectivitySettings
This class is a container for connectivity settings.
Definition: connectivitysettings.h:91
CONNECTIVITYLIB::NetworkNode::SPtr
QSharedPointer< NetworkNode > SPtr
Definition: networknode.h:85
CONNECTIVITYLIB::NetworkNode
This class holds an object to describe the node of a network.
Definition: networknode.h:81