v2.0.0
Loading...
Searching...
No Matches
ml_onnx_model.cpp
Go to the documentation of this file.
1//=============================================================================================================
31
32//=============================================================================================================
33// INCLUDES
34//=============================================================================================================
35
36#include "ml_onnx_model.h"
37
38#ifdef MNE_USE_ONNXRUNTIME
39#include <onnxruntime_cxx_api.h>
40#endif
41
42//=============================================================================================================
43// QT INCLUDES
44//=============================================================================================================
45
46#include <QDebug>
47#include <QFileInfo>
48
49//=============================================================================================================
50// STL INCLUDES
51//=============================================================================================================
52
53#include <stdexcept>
54
55//=============================================================================================================
56// USED NAMESPACES
57//=============================================================================================================
58
59using namespace MLLIB;
60
61//=============================================================================================================
62// STATIC HELPERS
63//=============================================================================================================
64
#ifdef MNE_USE_ONNXRUNTIME
// Returns the single ONNX Runtime environment used when creating sessions in load().
// The environment is a function-local static: it is built on the first call
// (initialisation is thread-safe per the C++ standard) and reused afterwards.
// It logs at WARNING level under the tag "mne-cpp".
Ort::Env& MlOnnxModel::ortEnv()
{
    static Ort::Env s_sharedEnv{ORT_LOGGING_LEVEL_WARNING, "mne-cpp"};
    return s_sharedEnv;
}
#endif
72
73//=============================================================================================================
74// DEFINE MEMBER METHODS
75//=============================================================================================================
76
80
81//=============================================================================================================
82
84{
85 // Explicit destructor needed so unique_ptr can see the complete ORT types.
86}
87
88//=============================================================================================================
89
91{
92#ifdef MNE_USE_ONNXRUNTIME
93 if (!m_session) {
94 throw std::runtime_error("MlOnnxModel::predict – No ONNX model loaded. Call load() first.");
95 }
96
97 // Create input Ort::Value — zero-copy from MlTensor's row-major float buffer
98 auto inputShape = input.shape();
99 Ort::Value inputTensor = Ort::Value::CreateTensor<float>(
100 *m_memoryInfo,
101 const_cast<float*>(input.data()),
102 static_cast<size_t>(input.size()),
103 inputShape.data(),
104 inputShape.size());
105
106 // Build C-string name arrays for Run()
107 std::vector<const char*> inputNamePtrs;
108 inputNamePtrs.reserve(m_inputNames.size());
109 for (const auto& n : m_inputNames)
110 inputNamePtrs.push_back(n.c_str());
111
112 std::vector<const char*> outputNamePtrs;
113 outputNamePtrs.reserve(m_outputNames.size());
114 for (const auto& n : m_outputNames)
115 outputNamePtrs.push_back(n.c_str());
116
117 // Run inference
118 Ort::RunOptions runOpts;
119 auto outputTensors = m_session->Run(
120 runOpts,
121 inputNamePtrs.data(), &inputTensor, inputNamePtrs.size(),
122 outputNamePtrs.data(), outputNamePtrs.size());
123
124 if (outputTensors.empty() || !outputTensors[0].IsTensor()) {
125 throw std::runtime_error("MlOnnxModel::predict – Model produced no valid output tensor.");
126 }
127
128 // Extract output shape and data — copy into an owning MlTensor
129 auto outputInfo = outputTensors[0].GetTensorTypeAndShapeInfo();
130 std::vector<int64_t> outputShape = outputInfo.GetShape();
131 const float* outputData = outputTensors[0].GetTensorData<float>();
132
133 return MlTensor(outputData, std::move(outputShape));
134#else
135 Q_UNUSED(input);
136 throw std::runtime_error("ONNX Runtime not available. Build with -DUSE_ONNXRUNTIME=ON");
137#endif
138}
139
140//=============================================================================================================
141
142bool MlOnnxModel::save(const QString& path) const
143{
144 Q_UNUSED(path);
145 qWarning() << "MlOnnxModel::save – ONNX models cannot be saved from this interface.";
146 return false;
147}
148
149//=============================================================================================================
150
151bool MlOnnxModel::load(const QString& path)
152{
153#ifdef MNE_USE_ONNXRUNTIME
154 m_modelPath = path;
155
156 if (!QFileInfo::exists(path)) {
157 qWarning() << "MlOnnxModel::load – File does not exist:" << path;
158 return false;
159 }
160
161 try {
162 // Session options — enable all graph optimizations, single intra-op thread for determinism
163 Ort::SessionOptions sessionOpts;
164 sessionOpts.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
165 sessionOpts.SetIntraOpNumThreads(1);
166 sessionOpts.DisableMemPattern(); // reduces peak memory for small models
167
168 // Create session from the ONNX file
169 std::string modelPathStd = path.toStdString();
170 m_session = std::make_unique<Ort::Session>(ortEnv(), modelPathStd.c_str(), sessionOpts);
171
172 // CPU memory info (reused for every predict call)
173 m_memoryInfo = std::make_unique<Ort::MemoryInfo>(
174 Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault));
175
176 // Cache input names and shapes
177 Ort::AllocatorWithDefaultOptions allocator;
178 size_t numInputs = m_session->GetInputCount();
179 m_inputNames.clear();
180 m_inputShapes.clear();
181 m_inputNames.reserve(numInputs);
182 m_inputShapes.reserve(numInputs);
183
184 for (size_t i = 0; i < numInputs; ++i) {
185 auto namePtr = m_session->GetInputNameAllocated(i, allocator);
186 m_inputNames.emplace_back(namePtr.get());
187
188 auto typeInfo = m_session->GetInputTypeInfo(i);
189 auto shape = typeInfo.GetTensorTypeAndShapeInfo().GetShape();
190 // Replace dynamic dimensions (-1) with 1 for logging; actual shape comes from input tensor
191 m_inputShapes.push_back(std::move(shape));
192 }
193
194 // Cache output names
195 size_t numOutputs = m_session->GetOutputCount();
196 m_outputNames.clear();
197 m_outputNames.reserve(numOutputs);
198 for (size_t i = 0; i < numOutputs; ++i) {
199 auto namePtr = m_session->GetOutputNameAllocated(i, allocator);
200 m_outputNames.emplace_back(namePtr.get());
201 }
202
203 qDebug() << "MlOnnxModel::load – Session created for" << path
204 << "(" << numInputs << "inputs," << numOutputs << "outputs)";
205 return true;
206
207 } catch (const Ort::Exception& e) {
208 qWarning() << "MlOnnxModel::load – ORT error:" << e.what();
209 m_session.reset();
210 return false;
211 }
212#else
213 m_modelPath = path;
214 qWarning() << "MlOnnxModel::load – Path stored but no ONNX Runtime session created (build with -DUSE_ONNXRUNTIME=ON).";
215 return false;
216#endif
217}
218
219//=============================================================================================================
220
222{
223#ifdef MNE_USE_ONNXRUNTIME
224 return m_session != nullptr;
225#else
226 return false;
227#endif
228}
229
230//=============================================================================================================
231
233{
234 return QStringLiteral("onnx");
235}
236
237//=============================================================================================================
238
240{
241 return m_taskType;
242}
ONNX Runtime-backed MLLIB::MlModel implementation for loading and evaluating .onnx graphs.
Tensors, model abstraction, ONNX Runtime inference and Python training drivers used across mne-cpp.
MlTaskType
Definition ml_types.h:51
MlTaskType taskType() const override
bool load(const QString &path) override
QString modelType() const override
MlTensor predict(const MlTensor &input) const override
bool save(const QString &path) const override
N-dimensional row-major float32 tensor with shared-buffer storage, Eigen Map accessors and a non-owni...
Definition ml_tensor.h:73
int64_t size() const
const std::vector< int64_t > & shape() const