v2.0.0
Loading...
Searching...
No Matches
ml_onnx_model.cpp
Go to the documentation of this file.
1//=============================================================================================================
34
35//=============================================================================================================
36// INCLUDES
37//=============================================================================================================
38
39#include "ml_onnx_model.h"
40
41#ifdef MNE_USE_ONNXRUNTIME
42#include <onnxruntime_cxx_api.h>
43#endif
44
45//=============================================================================================================
46// QT INCLUDES
47//=============================================================================================================
48
49#include <QDebug>
50#include <QFileInfo>
51
52//=============================================================================================================
53// STL INCLUDES
54//=============================================================================================================
55
56#include <stdexcept>
57
58//=============================================================================================================
59// USED NAMESPACES
60//=============================================================================================================
61
62using namespace MLLIB;
63
64//=============================================================================================================
65// STATIC HELPERS
66//=============================================================================================================
67
68#ifdef MNE_USE_ONNXRUNTIME
69Ort::Env& MlOnnxModel::ortEnv()
70{
71 static Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "mne-cpp");
72 return env;
73}
74#endif
75
76//=============================================================================================================
77// DEFINE MEMBER METHODS
78//=============================================================================================================
79
83
84//=============================================================================================================
85
87{
88 // Explicit destructor needed so unique_ptr can see the complete ORT types.
89}
90
91//=============================================================================================================
92
94{
95#ifdef MNE_USE_ONNXRUNTIME
96 if (!m_session) {
97 throw std::runtime_error("MlOnnxModel::predict – No ONNX model loaded. Call load() first.");
98 }
99
100 // Create input Ort::Value — zero-copy from MlTensor's row-major float buffer
101 auto inputShape = input.shape();
102 Ort::Value inputTensor = Ort::Value::CreateTensor<float>(
103 *m_memoryInfo,
104 const_cast<float*>(input.data()),
105 static_cast<size_t>(input.size()),
106 inputShape.data(),
107 inputShape.size());
108
109 // Build C-string name arrays for Run()
110 std::vector<const char*> inputNamePtrs;
111 inputNamePtrs.reserve(m_inputNames.size());
112 for (const auto& n : m_inputNames)
113 inputNamePtrs.push_back(n.c_str());
114
115 std::vector<const char*> outputNamePtrs;
116 outputNamePtrs.reserve(m_outputNames.size());
117 for (const auto& n : m_outputNames)
118 outputNamePtrs.push_back(n.c_str());
119
120 // Run inference
121 Ort::RunOptions runOpts;
122 auto outputTensors = m_session->Run(
123 runOpts,
124 inputNamePtrs.data(), &inputTensor, inputNamePtrs.size(),
125 outputNamePtrs.data(), outputNamePtrs.size());
126
127 if (outputTensors.empty() || !outputTensors[0].IsTensor()) {
128 throw std::runtime_error("MlOnnxModel::predict – Model produced no valid output tensor.");
129 }
130
131 // Extract output shape and data — copy into an owning MlTensor
132 auto outputInfo = outputTensors[0].GetTensorTypeAndShapeInfo();
133 std::vector<int64_t> outputShape = outputInfo.GetShape();
134 const float* outputData = outputTensors[0].GetTensorData<float>();
135
136 return MlTensor(outputData, std::move(outputShape));
137#else
138 Q_UNUSED(input);
139 throw std::runtime_error("ONNX Runtime not available. Build with -DUSE_ONNXRUNTIME=ON");
140#endif
141}
142
143//=============================================================================================================
144
145bool MlOnnxModel::save(const QString& path) const
146{
147 Q_UNUSED(path);
148 qWarning() << "MlOnnxModel::save – ONNX models cannot be saved from this interface.";
149 return false;
150}
151
152//=============================================================================================================
153
154bool MlOnnxModel::load(const QString& path)
155{
156#ifdef MNE_USE_ONNXRUNTIME
157 m_modelPath = path;
158
159 if (!QFileInfo::exists(path)) {
160 qWarning() << "MlOnnxModel::load – File does not exist:" << path;
161 return false;
162 }
163
164 try {
165 // Session options — enable all graph optimizations, single intra-op thread for determinism
166 Ort::SessionOptions sessionOpts;
167 sessionOpts.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
168 sessionOpts.SetIntraOpNumThreads(1);
169 sessionOpts.DisableMemPattern(); // reduces peak memory for small models
170
171 // Create session from the ONNX file
172 std::string modelPathStd = path.toStdString();
173 m_session = std::make_unique<Ort::Session>(ortEnv(), modelPathStd.c_str(), sessionOpts);
174
175 // CPU memory info (reused for every predict call)
176 m_memoryInfo = std::make_unique<Ort::MemoryInfo>(
177 Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault));
178
179 // Cache input names and shapes
180 Ort::AllocatorWithDefaultOptions allocator;
181 size_t numInputs = m_session->GetInputCount();
182 m_inputNames.clear();
183 m_inputShapes.clear();
184 m_inputNames.reserve(numInputs);
185 m_inputShapes.reserve(numInputs);
186
187 for (size_t i = 0; i < numInputs; ++i) {
188 auto namePtr = m_session->GetInputNameAllocated(i, allocator);
189 m_inputNames.emplace_back(namePtr.get());
190
191 auto typeInfo = m_session->GetInputTypeInfo(i);
192 auto shape = typeInfo.GetTensorTypeAndShapeInfo().GetShape();
193 // Replace dynamic dimensions (-1) with 1 for logging; actual shape comes from input tensor
194 m_inputShapes.push_back(std::move(shape));
195 }
196
197 // Cache output names
198 size_t numOutputs = m_session->GetOutputCount();
199 m_outputNames.clear();
200 m_outputNames.reserve(numOutputs);
201 for (size_t i = 0; i < numOutputs; ++i) {
202 auto namePtr = m_session->GetOutputNameAllocated(i, allocator);
203 m_outputNames.emplace_back(namePtr.get());
204 }
205
206 qDebug() << "MlOnnxModel::load – Session created for" << path
207 << "(" << numInputs << "inputs," << numOutputs << "outputs)";
208 return true;
209
210 } catch (const Ort::Exception& e) {
211 qWarning() << "MlOnnxModel::load – ORT error:" << e.what();
212 m_session.reset();
213 return false;
214 }
215#else
216 m_modelPath = path;
217 qWarning() << "MlOnnxModel::load – Path stored but no ONNX Runtime session created (build with -DUSE_ONNXRUNTIME=ON).";
218 return false;
219#endif
220}
221
222//=============================================================================================================
223
225{
226#ifdef MNE_USE_ONNXRUNTIME
227 return m_session != nullptr;
228#else
229 return false;
230#endif
231}
232
233//=============================================================================================================
234
236{
237 return QStringLiteral("onnx");
238}
239
240//=============================================================================================================
241
243{
244 return m_taskType;
245}
MlOnnxModel class declaration.
Machine learning (models, pipelines, ONNX Runtime integration).
MlTaskType
Definition ml_types.h:60
MlTaskType taskType() const override
bool load(const QString &path) override
QString modelType() const override
MlTensor predict(const MlTensor &input) const override
bool save(const QString &path) const override
N-dimensional tensor with contiguous row-major (C-order) float32 storage.
Definition ml_tensor.h:79
int64_t size() const
const std::vector< int64_t > & shape() const