v2.0.0
Loading...
Searching...
No Matches
ml_linear_model.cpp
Go to the documentation of this file.
1//=============================================================================================================
34
35//=============================================================================================================
36// INCLUDES
37//=============================================================================================================
38
39#include "ml_linear_model.h"
40
41//=============================================================================================================
42// EIGEN INCLUDES
43//=============================================================================================================
44
45#include <Eigen/Dense>
46
47//=============================================================================================================
48// QT INCLUDES
49//=============================================================================================================
50
51#include <QFile>
52#include <QJsonDocument>
53#include <QJsonObject>
54#include <QJsonArray>
55#include <QDebug>
56
57//=============================================================================================================
58// STL INCLUDES
59//=============================================================================================================
60
61#include <stdexcept>
62#include <cmath>
63
64//=============================================================================================================
65// USED NAMESPACES
66//=============================================================================================================
67
68using namespace MLLIB;
69using namespace Eigen;
70
71//=============================================================================================================
72// DEFINE MEMBER METHODS
73//=============================================================================================================
74
75MlLinearModel::MlLinearModel(MlTaskType type, double regularization)
76: m_taskType(type)
77, m_regularization(regularization)
78{
79}
80
81//=============================================================================================================
82
84{
85 if (!m_trained) {
86 throw std::runtime_error("MlLinearModel::predict – model has not been trained.");
87 }
88
89 auto X = input.matrix();
90 MatrixXf result = (X * m_weights).rowwise() + m_bias.transpose();
91
92 if (m_taskType == MlTaskType::Classification) {
93 // Apply sigmoid: p = 1 / (1 + exp(-z))
94 result = (1.0f + (-result.array()).exp()).inverse().matrix();
95 }
96
97 return MlTensor(result);
98}
99
100//=============================================================================================================
101
102bool MlLinearModel::save(const QString& path) const
103{
104 if (!m_trained) {
105 qWarning() << "MlLinearModel::save – model not trained; nothing to save.";
106 return false;
107 }
108
109 QJsonObject root;
110 root[QStringLiteral("model_type")] = QStringLiteral("linear");
111
112 switch (m_taskType) {
113 case MlTaskType::Regression: root[QStringLiteral("task_type")] = QStringLiteral("regression"); break;
114 case MlTaskType::Classification: root[QStringLiteral("task_type")] = QStringLiteral("classification"); break;
115 case MlTaskType::FeatureExtraction: root[QStringLiteral("task_type")] = QStringLiteral("feature_extraction"); break;
116 }
117
118 root[QStringLiteral("regularization")] = m_regularization;
119 root[QStringLiteral("n_features")] = static_cast<int>(m_weights.rows());
120 root[QStringLiteral("n_outputs")] = static_cast<int>(m_weights.cols());
121
122 // Weights: array of arrays (row-major)
123 QJsonArray wArr;
124 for (int r = 0; r < m_weights.rows(); ++r) {
125 QJsonArray row;
126 for (int c = 0; c < m_weights.cols(); ++c) {
127 row.append(static_cast<double>(m_weights(r, c)));
128 }
129 wArr.append(row);
130 }
131 root[QStringLiteral("weights")] = wArr;
132
133 // Bias
134 QJsonArray bArr;
135 for (int i = 0; i < m_bias.size(); ++i) {
136 bArr.append(static_cast<double>(m_bias(i)));
137 }
138 root[QStringLiteral("bias")] = bArr;
139
140 QFile file(path);
141 if (!file.open(QIODevice::WriteOnly)) {
142 qWarning() << "MlLinearModel::save – cannot open file" << path;
143 return false;
144 }
145 file.write(QJsonDocument(root).toJson(QJsonDocument::Indented));
146 return true;
147}
148
149//=============================================================================================================
150
151bool MlLinearModel::load(const QString& path)
152{
153 QFile file(path);
154 if (!file.open(QIODevice::ReadOnly)) {
155 qWarning() << "MlLinearModel::load – cannot open file" << path;
156 return false;
157 }
158
159 QJsonDocument doc = QJsonDocument::fromJson(file.readAll());
160 if (doc.isNull()) {
161 qWarning() << "MlLinearModel::load – invalid JSON in" << path;
162 return false;
163 }
164
165 QJsonObject root = doc.object();
166
167 // Task type
168 QString tt = root[QStringLiteral("task_type")].toString();
169 if (tt == QStringLiteral("regression")) m_taskType = MlTaskType::Regression;
170 else if (tt == QStringLiteral("classification")) m_taskType = MlTaskType::Classification;
171 else if (tt == QStringLiteral("feature_extraction")) m_taskType = MlTaskType::FeatureExtraction;
172
173 m_regularization = root[QStringLiteral("regularization")].toDouble(1.0);
174
175 int nFeatures = root[QStringLiteral("n_features")].toInt();
176 int nOutputs = root[QStringLiteral("n_outputs")].toInt();
177
178 // Weights
179 QJsonArray wArr = root[QStringLiteral("weights")].toArray();
180 m_weights.resize(nFeatures, nOutputs);
181 for (int r = 0; r < nFeatures; ++r) {
182 QJsonArray row = wArr[r].toArray();
183 for (int c = 0; c < nOutputs; ++c) {
184 m_weights(r, c) = static_cast<float>(row[c].toDouble());
185 }
186 }
187
188 // Bias
189 QJsonArray bArr = root[QStringLiteral("bias")].toArray();
190 m_bias.resize(nOutputs);
191 for (int i = 0; i < nOutputs; ++i) {
192 m_bias(i) = static_cast<float>(bArr[i].toDouble());
193 }
194
195 m_trained = true;
196 return true;
197}
198
199//=============================================================================================================
200
202{
203 return QStringLiteral("linear");
204}
205
206//=============================================================================================================
207
209{
210 return m_taskType;
211}
212
213//=============================================================================================================
214
215const MatrixXf& MlLinearModel::weights() const
216{
217 return m_weights;
218}
219
220//=============================================================================================================
221
222const VectorXf& MlLinearModel::bias() const
223{
224 return m_bias;
225}
MlLinearModel class declaration.
constexpr int X
Machine learning (models, pipelines, ONNX Runtime integration).
MlTaskType
Definition ml_types.h:60
MlTensor predict(const MlTensor &input) const override
bool save(const QString &path) const override
MlLinearModel(MlTaskType type=MlTaskType::Regression, double regularization=1.0)
bool load(const QString &path) override
const Eigen::MatrixXf & weights() const
MlTaskType taskType() const override
const Eigen::VectorXf & bias() const
QString modelType() const override
N-dimensional tensor with contiguous row-major (C-order) float32 storage.
Definition ml_tensor.h:79
RowMajorMatrixMap matrix()