v2.0.0
Loading...
Searching...
No Matches
ml_tensor.h
Go to the documentation of this file.
1//=============================================================================================================
29
30#ifndef ML_TENSOR_H
31#define ML_TENSOR_H
32
33//=============================================================================================================
34// INCLUDES
35//=============================================================================================================
36
37#include "ml_global.h"
38
39//=============================================================================================================
40// EIGEN INCLUDES
41//=============================================================================================================
42
43#include <Eigen/Core>
44
45//=============================================================================================================
46// STL INCLUDES
47//=============================================================================================================
48
49#include <cassert>
50#include <cstdint>
51#include <memory>
52#include <vector>
53
54//=============================================================================================================
55// DEFINE NAMESPACE MLLIB
56//=============================================================================================================
57
58namespace MLLIB{
59
60//=============================================================================================================
73{
74public:
75 // --- type aliases used in the public API --------------------------------
76 using RowMajorMatrixXf = Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
77 using RowMajorMatrixMap = Eigen::Map<RowMajorMatrixXf>;
78 using ConstRowMajorMatrixMap = Eigen::Map<const RowMajorMatrixXf>;
79
80 //=========================================================================================================
84 MlTensor();
85
86 //=========================================================================================================
94 MlTensor(std::vector<float>&& data, std::vector<int64_t> shape);
95
96 //=========================================================================================================
103 MlTensor(const float* data, std::vector<int64_t> shape);
104
105 //=========================================================================================================
113 explicit MlTensor(const Eigen::MatrixXf& mat);
114
115 //=========================================================================================================
123 explicit MlTensor(const Eigen::MatrixXd& mat);
124
125 //=========================================================================================================
135 static MlTensor view(float* data, std::vector<int64_t> shape);
136
137 //=========================================================================================================
146 static MlTensor fromBuffer(const float* data, int rows, int cols);
147
148 // --- shape access -------------------------------------------------------
149
150 //=========================================================================================================
154 int ndim() const;
155
156 //=========================================================================================================
160 int64_t size() const;
161
162 //=========================================================================================================
166 const std::vector<int64_t>& shape() const;
167
168 //=========================================================================================================
173 int64_t shape(int dim) const;
174
175 //=========================================================================================================
180 int rows() const;
181
182 //=========================================================================================================
187 int cols() const;
188
189 // --- raw data access (zero-copy) ----------------------------------------
190
191 //=========================================================================================================
195 float* data();
196
197 //=========================================================================================================
201 const float* data() const;
202
203 // --- Eigen Map accessors (zero-copy, 2-D) -------------------------------
204
205 //=========================================================================================================
213
214 //=========================================================================================================
222
223 // --- Eigen copy helpers (produce column-major copies) -------------------
224
225 //=========================================================================================================
229 Eigen::MatrixXf toMatrixXf() const;
230
231 //=========================================================================================================
235 Eigen::MatrixXd toMatrixXd() const;
236
237 // --- reshape / query ----------------------------------------------------
238
239 //=========================================================================================================
249 MlTensor reshape(std::vector<int64_t> newShape) const;
250
251 //=========================================================================================================
255 bool isView() const;
256
257 //=========================================================================================================
261 bool empty() const;
262
263private:
264 static int64_t computeSize(const std::vector<int64_t>& shape);
265
266 std::shared_ptr<std::vector<float>> m_storage;
267 float* m_data = nullptr;
268 std::vector<int64_t> m_shape;
269 int64_t m_size = 0;
270};
271
272} // namespace MLLIB
273
274#endif // ML_TENSOR_H
Export/import macros, build-stamp accessors and namespace anchor for the MLLIB machine-learning libra...
#define MLSHARED_EXPORT
Definition ml_global.h:48
Tensors, model abstraction, ONNX Runtime inference and Python training drivers used across mne-cpp.
int cols() const
Eigen::MatrixXf toMatrixXf() const
Eigen::Map< RowMajorMatrixXf > RowMajorMatrixMap
Definition ml_tensor.h:77
Eigen::Matrix< float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor > RowMajorMatrixXf
Definition ml_tensor.h:76
static MlTensor fromBuffer(const float *data, int rows, int cols)
Eigen::MatrixXd toMatrixXd() const
static MlTensor view(float *data, std::vector< int64_t > shape)
MlTensor(const Eigen::MatrixXd &mat)
int ndim() const
RowMajorMatrixMap matrix()
int64_t size() const
MlTensor(const Eigen::MatrixXf &mat)
int rows() const
MlTensor reshape(std::vector< int64_t > newShape) const
bool empty() const
const std::vector< int64_t > & shape() const
bool isView() const
Eigen::Map< const RowMajorMatrixXf > ConstRowMajorMatrixMap
Definition ml_tensor.h:78