mirror of
https://github.com/rmontanana/mdlp.git
synced 2025-08-16 07:55:58 +00:00
Add torch methods to discretize
Add fit_transform methods
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <torch/torch.h>
|
||||
#include "typesFImdlp.h"
|
||||
|
||||
namespace mdlp {
|
||||
@@ -10,19 +11,14 @@ namespace mdlp {
|
||||
public:
|
||||
Discretizer() = default;
|
||||
virtual ~Discretizer() = default;
|
||||
virtual void fit(samples_t& X_, labels_t& y_) = 0;
|
||||
inline cutPoints_t getCutPoints() const { return cutPoints; };
|
||||
labels_t& transform(const samples_t& data)
|
||||
{
|
||||
discretizedData.clear();
|
||||
discretizedData.reserve(data.size());
|
||||
for (const precision_t& item : data) {
|
||||
auto upper = std::upper_bound(cutPoints.begin(), cutPoints.end(), item);
|
||||
discretizedData.push_back(upper - cutPoints.begin());
|
||||
}
|
||||
return discretizedData;
|
||||
};
|
||||
static inline std::string version() { return "1.2.0"; };
|
||||
virtual void fit(samples_t& X_, labels_t& y_) = 0;
|
||||
labels_t& transform(const samples_t& data);
|
||||
labels_t& fit_transform(samples_t& X_, labels_t& y_);
|
||||
void fit_t(torch::Tensor& X_, torch::Tensor& y_);
|
||||
torch::Tensor transform_t(torch::Tensor& X_);
|
||||
torch::Tensor fit_transform_t(torch::Tensor& X_, torch::Tensor& y_);
|
||||
static inline std::string version() { return "1.2.1"; };
|
||||
protected:
|
||||
labels_t discretizedData = labels_t();
|
||||
cutPoints_t cutPoints;
|
||||
|
Reference in New Issue
Block a user