MNE-CPP  0.1.9
A Framework for Electrophysiology
weightedphaselagindex.cpp
Go to the documentation of this file.
1 //=============================================================================================================
39 //=============================================================================================================
40 // INCLUDES
41 //=============================================================================================================
42 
43 #include "weightedphaselagindex.h"
44 #include "../network/networknode.h"
45 #include "../network/networkedge.h"
46 #include "../network/network.h"
47 
48 #include <utils/spectral.h>
49 
50 //=============================================================================================================
51 // QT INCLUDES
52 //=============================================================================================================
53 
54 #include <QDebug>
55 #include <QtConcurrent>
56 
57 //=============================================================================================================
58 // EIGEN INCLUDES
59 //=============================================================================================================
60 
61 #include <unsupported/Eigen/FFT>
62 
63 //=============================================================================================================
64 // USED NAMESPACES
65 //=============================================================================================================
66 
67 using namespace CONNECTIVITYLIB;
68 using namespace Eigen;
69 using namespace UTILSLIB;
70 
71 //=============================================================================================================
72 // DEFINE GLOBAL METHODS
73 //=============================================================================================================
74 
75 //=============================================================================================================
76 // DEFINE MEMBER METHODS
77 //=============================================================================================================
78 
81 {
82 }
83 
84 //*******************************************************************************************************
85 
87 {
88 // QElapsedTimer timer;
89 // qint64 iTime = 0;
90 // timer.start();
91 
92  Network finalNetwork("WPLI");
93 
94  if(connectivitySettings.isEmpty()) {
95  qWarning() << "WeightedPhaseLagIndex::calculate - Input data is empty";
96  return finalNetwork;
97  }
98 
99  if(AbstractMetric::m_bStorageModeIsActive == false) {
100  connectivitySettings.clearIntermediateData();
101  }
102 
103  finalNetwork.setSamplingFrequency(connectivitySettings.getSamplingFrequency());
104 
105  #ifdef EIGEN_FFTW_DEFAULT
106  fftw_make_planner_thread_safe();
107  #endif
108 
109  //Create nodes
110  int rows = connectivitySettings.at(0).matData.rows();
111  RowVectorXf rowVert = RowVectorXf::Zero(3);
112 
113  for(int i = 0; i < rows; ++i) {
114  rowVert = RowVectorXf::Zero(3);
115 
116  if(connectivitySettings.getNodePositions().rows() != 0 && i < connectivitySettings.getNodePositions().rows()) {
117  rowVert(0) = connectivitySettings.getNodePositions().row(i)(0);
118  rowVert(1) = connectivitySettings.getNodePositions().row(i)(1);
119  rowVert(2) = connectivitySettings.getNodePositions().row(i)(2);
120  }
121 
122  finalNetwork.append(NetworkNode::SPtr(new NetworkNode(i, rowVert)));
123  }
124 
125  // Check that iNfft >= signal length
126  int iSignalLength = connectivitySettings.at(0).matData.cols();
127  int iNfft = connectivitySettings.getFFTSize();
128 
129  // Generate tapers
130  QPair<MatrixXd, VectorXd> tapers = Spectral::generateTapers(iSignalLength, connectivitySettings.getWindowType());
131 
132  // Initialize
133  int iNRows = connectivitySettings.at(0).matData.rows();
134  int iNFreqs = int(floor(iNfft / 2.0)) + 1;
135 
136  // Check if start and bin amount need to be reset to full spectrum
137  if(m_iNumberBinStart == -1 ||
138  m_iNumberBinAmount == -1 ||
139  m_iNumberBinStart > iNFreqs ||
140  m_iNumberBinAmount > iNFreqs ||
141  m_iNumberBinAmount + m_iNumberBinStart > iNFreqs) {
142  qDebug() << "WeightedPhaseLagIndex::calculate - Resetting to full spectrum";
143  AbstractMetric::m_iNumberBinStart = 0;
144  AbstractMetric::m_iNumberBinAmount = iNFreqs;
145  }
146 
147  // Pass information about the FFT length. Use iNFreqs because we only use the half spectrum
148  finalNetwork.setFFTSize(iNFreqs);
149  finalNetwork.setUsedFreqBins(AbstractMetric::m_iNumberBinAmount);
150 
151  QMutex mutex;
152 
153  std::function<void(ConnectivitySettings::IntermediateTrialData&)> computeLambda = [&](ConnectivitySettings::IntermediateTrialData& inputData) {
154  compute(inputData,
155  connectivitySettings.getIntermediateSumData().vecPairCsdSum,
156  connectivitySettings.getIntermediateSumData().vecPairCsdImagAbsSum,
157  mutex,
158  iNRows,
159  iNFreqs,
160  iNfft,
161  tapers);
162  };
163 
164 // iTime = timer.elapsed();
165 // qWarning() << "Preparation" << iTime;
166 // timer.restart();
167 
168  // Compute WPLI in parallel for all trials
169  QFuture<void> result = QtConcurrent::map(connectivitySettings.getTrialData(),
170  computeLambda);
171  result.waitForFinished();
172 
173 // iTime = timer.elapsed();
174 // qWarning() << "ComputeSpectraPSDCSD" << iTime;
175 // timer.restart();
176 
177  // Compute WPLI
178  computeWPLI(connectivitySettings,
179  finalNetwork);
180 
181 // iTime = timer.elapsed();
182 // qWarning() << "Compute" << iTime;
183 // timer.restart();
184 
185  return finalNetwork;
186 }
187 
188 //=============================================================================================================
189 
191  QVector<QPair<int,MatrixXcd> >& vecPairCsdSum,
192  QVector<QPair<int,MatrixXd> >& vecPairCsdImagAbsSum,
193  QMutex& mutex,
194  int iNRows,
195  int iNFreqs,
196  int iNfft,
197  const QPair<MatrixXd, VectorXd>& tapers)
198 {
199 // QElapsedTimer timer;
200 // qint64 iTime = 0;
201 // timer.start();
202 
203  if(inputData.vecPairCsd.size() == iNRows &&
204  inputData.vecPairCsdImagAbs.size() == iNRows ) {
205  //qDebug() << "WeightedPhaseLagIndex::compute - vecPairCsd and vecPairCsdImagAbs were already computed for this trial.";
206  return;
207  }
208 
209  int i,j;
210 
211  // Calculate tapered spectra if not available already
212  // This code was copied and changed modified Utils/Spectra since we do not want to call the function due to time loss.
213  if(inputData.vecTapSpectra.size() != iNRows) {
214  inputData.vecTapSpectra.clear();
215 
216  RowVectorXd vecInputFFT, rowData;
217  RowVectorXcd vecTmpFreq;
218 
219  MatrixXcd matTapSpectrum(tapers.first.rows(), iNFreqs);
220 
221  FFT<double> fft;
222  fft.SetFlag(fft.HalfSpectrum);
223 
224  for (i = 0; i < iNRows; ++i) {
225  // Substract mean
226  rowData.array() = inputData.matData.row(i).array() - inputData.matData.row(i).mean();
227 
228  // Calculate tapered spectra if not available already
229  for(j = 0; j < tapers.first.rows(); j++) {
230  // Zero padd if necessary. The zero padding in Eigen's FFT is only working for column vectors.
231  if (rowData.cols() < iNfft) {
232  vecInputFFT.setZero(iNfft);
233  vecInputFFT.block(0,0,1,rowData.cols()) = rowData.cwiseProduct(tapers.first.row(j));;
234  } else {
235  vecInputFFT = rowData.cwiseProduct(tapers.first.row(j));
236  }
237 
238  // FFT for freq domain returning the half spectrum and multiply taper weights
239  fft.fwd(vecTmpFreq, vecInputFFT, iNfft);
240  matTapSpectrum.row(j) = vecTmpFreq * tapers.second(j);
241  }
242 
243  inputData.vecTapSpectra.append(matTapSpectrum);
244  }
245 
246 // iTime = timer.elapsed();
247 // qWarning() << "WeightedPhaseLagIndex::compute timer - Compute spectra:" << iTime;
248 // timer.restart();
249  }
250 
251  // Compute CSD
252  if(inputData.vecPairCsd.isEmpty()) {
253  double denomCSD = sqrt(tapers.second.cwiseAbs2().sum()) * sqrt(tapers.second.cwiseAbs2().sum()) / 2.0;
254  bool bNfftEven = false;
255  if (iNfft % 2 == 0){
256  bNfftEven = true;
257  }
258 
259  MatrixXcd matCsd = MatrixXcd(iNRows, m_iNumberBinAmount);
260 
261  for (i = 0; i < iNRows; ++i) {
262  for (j = i; j < iNRows; ++j) {
263  // Compute CSD (average over tapers if necessary)
264  matCsd.row(j) = inputData.vecTapSpectra.at(i).block(0,m_iNumberBinStart,inputData.vecTapSpectra.at(i).rows(),m_iNumberBinAmount).cwiseProduct(inputData.vecTapSpectra.at(j).block(0,m_iNumberBinStart,inputData.vecTapSpectra.at(j).rows(),m_iNumberBinAmount).conjugate()).colwise().sum() / denomCSD;
265 
266  // Divide first and last element by 2 due to half spectrum
267  if(m_iNumberBinStart == 0) {
268  matCsd.row(j)(0) /= 2.0;
269  }
270 
271  if(bNfftEven && m_iNumberBinStart + m_iNumberBinAmount >= iNFreqs) {
272  matCsd.row(j).tail(1) /= 2.0;
273  }
274  }
275 
276  inputData.vecPairCsd.append(QPair<int,MatrixXcd>(i,matCsd));
277  inputData.vecPairCsdImagAbs.append(QPair<int,MatrixXd>(i,matCsd.imag().cwiseAbs()));
278  }
279 
280 // iTime = timer.elapsed();
281 // qWarning() << "WeightedPhaseLagIndex::compute timer - Compute CSD and Imag CSD:" << iTime;
282 // timer.restart();
283 
284  mutex.lock();
285 
286  if(vecPairCsdSum.isEmpty()) {
287  vecPairCsdSum = inputData.vecPairCsd;
288  vecPairCsdImagAbsSum = inputData.vecPairCsdImagAbs;
289  } else {
290  for (int j = 0; j < vecPairCsdSum.size(); ++j) {
291  vecPairCsdSum[j].second += inputData.vecPairCsd.at(j).second;
292  vecPairCsdImagAbsSum[j].second += inputData.vecPairCsdImagAbs.at(j).second;
293  }
294  }
295 
296  mutex.unlock();
297 
298 // iTime = timer.elapsed();
299 // qWarning() << "WeightedPhaseLagIndex::compute timer - Add CSD to sum:" << iTime;
300 // timer.restart();
301  } else {
302  if (inputData.vecPairCsdImagAbs.isEmpty()) {
303  inputData.vecPairCsdImagAbs.clear();
304  for (i = 0; i < inputData.vecPairCsd.size(); ++i) {
305  inputData.vecPairCsdImagAbs.append(QPair<int,MatrixXd>(i,inputData.vecPairCsd.at(i).second.imag().cwiseAbs()));
306  }
307 
308  mutex.lock();
309 
310  if(vecPairCsdImagAbsSum.isEmpty()) {
311  vecPairCsdImagAbsSum = inputData.vecPairCsdImagAbs;
312  } else {
313  for (int j = 0; j < vecPairCsdImagAbsSum.size(); ++j) {
314  vecPairCsdImagAbsSum[j].second += inputData.vecPairCsdImagAbs.at(j).second;
315  }
316  }
317 
318  mutex.unlock();
319  }
320  }
321 
322  //Do not store data to save memory
323  if(!m_bStorageModeIsActive) {
324  inputData.vecPairCsd.clear();
325  inputData.vecPairCsdImagAbs.clear();
326  inputData.vecTapSpectra.clear();
327  }
328 }
329 
330 //=============================================================================================================
331 
333  Network& finalNetwork)
334 {
335  // Compute final WPLI and create Network
336  MatrixXd matDenom, matNom;
337  MatrixXd matWeight;
338  QSharedPointer<NetworkEdge> pEdge;
339  int j;
340 
341  for (int i = 0; i < connectivitySettings.getIntermediateSumData().vecPairCsdSum.size(); ++i) {
342  matDenom = connectivitySettings.getIntermediateSumData().vecPairCsdImagAbsSum.at(i).second;
343  matDenom = (matDenom.array() == 0.).select(INFINITY, matDenom);
344 
345  matNom = connectivitySettings.getIntermediateSumData().vecPairCsdSum.at(i).second.imag().cwiseAbs().cwiseQuotient(matDenom);
346 
347  for(j = i; j < matNom.rows(); ++j) {
348  matWeight = matNom.row(j).transpose();
349 
350  pEdge = QSharedPointer<NetworkEdge>(new NetworkEdge(i, j, matWeight));
351 
352  finalNetwork.getNodeAt(i)->append(pEdge);
353  finalNetwork.getNodeAt(j)->append(pEdge);
354  finalNetwork.append(pEdge);
355  }
356  }
357 }
358 
CONNECTIVITYLIB::ConnectivitySettings::IntermediateTrialData
Definition: connectivitysettings.h:98
spectral.h
Declaration of Spectral class.
CONNECTIVITYLIB::WeightedPhaseLagIndex::computeWPLI
static void computeWPLI(ConnectivitySettings &connectivitySettings, Network &finalNetwork)
Definition: weightedphaselagindex.cpp:332
CONNECTIVITYLIB::Network
This class holds information about a network, can compute a distance table and provide network metric...
Definition: network.h:88
CONNECTIVITYLIB::AbstractMetric
This class provides basic functionalities for all implemented metrics.
Definition: abstractmetric.h:77
CONNECTIVITYLIB::WeightedPhaseLagIndex::WeightedPhaseLagIndex
WeightedPhaseLagIndex()
Definition: weightedphaselagindex.cpp:79
CONNECTIVITYLIB::Network::append
void append(QSharedPointer< NetworkEdge > newEdge)
CONNECTIVITYLIB::Network::getNodeAt
QSharedPointer< NetworkNode > getNodeAt(int i)
Definition: network.cpp:163
weightedphaselagindex.h
WeightedPhaseLagIndex class declaration.
CONNECTIVITYLIB::Network::setSamplingFrequency
void setSamplingFrequency(float fSFreq)
Definition: network.cpp:492
CONNECTIVITYLIB::Network::setUsedFreqBins
void setUsedFreqBins(int iNumberFreqBins)
Definition: network.cpp:506
CONNECTIVITYLIB::NetworkEdge
This class holds an object to describe the edge of a network.
Definition: networkedge.h:79
CONNECTIVITYLIB::WeightedPhaseLagIndex::compute
static void compute(ConnectivitySettings::IntermediateTrialData &inputData, QVector< QPair< int, Eigen::MatrixXcd > > &vecPairCsdSum, QVector< QPair< int, Eigen::MatrixXd > > &vecPairCsdImagAbsSum, QMutex &mutex, int iNRows, int iNFreqs, int iNfft, const QPair< Eigen::MatrixXd, Eigen::VectorXd > &tapers)
Definition: weightedphaselagindex.cpp:190
CONNECTIVITYLIB::Network::setFFTSize
void setFFTSize(int iFFTSize)
Definition: network.cpp:513
CONNECTIVITYLIB::ConnectivitySettings
This class is a container for connectivity settings.
Definition: connectivitysettings.h:91
CONNECTIVITYLIB::NetworkNode::SPtr
QSharedPointer< NetworkNode > SPtr
Definition: networknode.h:85
CONNECTIVITYLIB::WeightedPhaseLagIndex::calculate
static Network calculate(ConnectivitySettings &connectivitySettings)
Definition: weightedphaselagindex.cpp:86
CONNECTIVITYLIB::NetworkNode
This class holds an object to describe the node of a network.
Definition: networknode.h:81