Add torch methods to discretize

Add fit_transform methods
This commit is contained in:
2024-06-07 23:54:42 +02:00
parent 633aa52849
commit e205668906
14 changed files with 183 additions and 87 deletions

View File

@@ -3,6 +3,7 @@
#include <string>
#include <algorithm>
#include <torch/torch.h>
#include "typesFImdlp.h"
namespace mdlp {
@@ -10,19 +11,14 @@ namespace mdlp {
public:
Discretizer() = default;
virtual ~Discretizer() = default;
virtual void fit(samples_t& X_, labels_t& y_) = 0;
inline cutPoints_t getCutPoints() const { return cutPoints; };
labels_t& transform(const samples_t& data)
{
discretizedData.clear();
discretizedData.reserve(data.size());
for (const precision_t& item : data) {
auto upper = std::upper_bound(cutPoints.begin(), cutPoints.end(), item);
discretizedData.push_back(upper - cutPoints.begin());
}
return discretizedData;
};
static inline std::string version() { return "1.2.0"; };
virtual void fit(samples_t& X_, labels_t& y_) = 0;
labels_t& transform(const samples_t& data);
labels_t& fit_transform(samples_t& X_, labels_t& y_);
void fit_t(torch::Tensor& X_, torch::Tensor& y_);
torch::Tensor transform_t(torch::Tensor& X_);
torch::Tensor fit_transform_t(torch::Tensor& X_, torch::Tensor& y_);
static inline std::string version() { return "1.2.1"; };
protected:
labels_t discretizedData = labels_t();
cutPoints_t cutPoints;