Fit the discretizer only on the training data (the test data is transformed with the already-fitted discretizer)

This commit is contained in:
2024-06-09 00:50:55 +02:00
parent 361c51d864
commit 643633e6dd
9 changed files with 38 additions and 44 deletions

View File

@@ -7,7 +7,7 @@ namespace platform {
path(dataset.path), name(dataset.name), className(dataset.className), n_samples(dataset.n_samples),
n_features(dataset.n_features), numericFeatures(dataset.numericFeatures), features(dataset.features),
states(dataset.states), loaded(dataset.loaded), discretize(dataset.discretize), X(dataset.X), y(dataset.y),
X_train(dataset.X_train), X_test(dataset.X_test), Xv(dataset.Xv), Xd(dataset.Xd), yv(dataset.yv),
X_train(dataset.X_train), X_test(dataset.X_test), Xv(dataset.Xv), yv(dataset.yv),
fileType(dataset.fileType)
{
}
@@ -46,9 +46,6 @@ namespace platform {
int Dataset::getNClasses() const
{
if (loaded) {
if (discretize) {
return states.at(className).size();
}
return *std::max_element(yv.begin(), yv.end()) + 1;
} else {
throw std::invalid_argument(message_dataset_not_loaded);
@@ -91,14 +88,6 @@ namespace platform {
throw std::invalid_argument(message_dataset_not_loaded);
}
}
/// Accessor for the discretized feature vectors together with the label vector.
///
/// @return a pair of references to the members `Xd` and `yv`.
/// @throws std::invalid_argument (with `message_dataset_not_loaded`) when the
///         `loaded` flag is not set.
pair<std::vector<std::vector<int>>&, std::vector<int>&> Dataset::getVectorsDiscretized()
{
    // Guard clause: refuse to hand out references before the data is loaded.
    if (!loaded) {
        throw std::invalid_argument(message_dataset_not_loaded);
    }
    return { Xd, yv };
}
pair<torch::Tensor&, torch::Tensor&> Dataset::getTensors()
{
if (loaded) {
@@ -140,11 +129,13 @@ namespace platform {
/// Rebuild the `states` map from the training tensors.
///
/// For every entry of `features`, the maximum value found in row `i` of
/// `X_train` is taken and the state list `0, 1, ..., max` is stored under that
/// feature name. The same is done for `className` using `y_train`.
///
/// NOTE(review): this presumes `X_train` / `y_train` already hold integer codes
/// starting at 0 (i.e. it runs after discretization) -- confirm with callers.
void Dataset::computeStates()
{
    for (int i = 0; i < static_cast<int>(features.size()); ++i) {
        auto [max_value, idx] = torch::max(X_train.index({ i, "..." }), 0);
        // Bind a reference to the stored vector: filling a by-value copy would
        // leave the entry kept in `states` as all zeros.
        auto& feature_states = states[features[i]];
        feature_states = std::vector<int>(max_value.item<int>() + 1);
        std::iota(feature_states.begin(), feature_states.end(), 0);
    }
    auto [max_value, idx] = torch::max(y_train, 0);
    auto& class_states = states[className];
    class_states = std::vector<int>(max_value.item<int>() + 1);
    std::iota(class_states.begin(), class_states.end(), 0);
}
void Dataset::load_arff()
@@ -245,17 +236,6 @@ namespace platform {
y = torch::tensor(yv, torch::kInt32);
loaded = true;
}
/// Discretize every entry of `X` against the labels `y`.
///
/// A single `mdlp::CPPFImdlp` instance is re-fitted on each entry of `X` in
/// turn and then used to transform that same entry; the transformed results
/// are collected in the same order as `X`.
///
/// @param X  collection of sample vectors, visited in order
/// @param y  labels passed to every `fit` call
/// @return   one `mdlp::labels_t` per entry of `X`
std::vector<mdlp::labels_t> Dataset::discretizeDataset(std::vector<mdlp::samples_t>& X, mdlp::labels_t& y)
{
    std::vector<mdlp::labels_t> discretized;
    mdlp::CPPFImdlp discretizer{};
    for (auto& column : X) {
        discretizer.fit(column, y);
        // transform() hands back a reference; push_back stores a copy of it.
        discretized.push_back(discretizer.transform(column));
    }
    return discretized;
}
std::tuple<torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&> Dataset::getTrainTestTensors(std::vector<int>& train, std::vector<int>& test)
{
if (!loaded) {
@@ -273,15 +253,14 @@ namespace platform {
auto discretizer = Discretization::instance()->create(discretizer_algorithm);
auto X_train_d = torch::zeros({ n_features, samples_train }, torch::kInt32);
auto X_test_d = torch::zeros({ n_features, samples_test }, torch::kInt32);
for (int feature = 0; feature < n_features; ++feature) {
for (auto feature = 0; feature < n_features; ++feature) {
if (numericFeatures[feature]) {
auto X_train_feature = X_train.index({ feature, "..." }).to(torch::kFloat32);
auto X_test_feature = X_test.index({ feature, "..." }).to(torch::kFloat32);
discretizer->fit(X_train_feature, y_train);
auto X_train_feature_d = discretizer->transform(X_train_feature);
auto X_test_feature_d = discretizer->transform(X_test_feature);
X_train_d.index_put_({ feature, "..." }, X_train_feature_d.to(torch::kInt32));
X_test_d.index_put_({ feature, "..." }, X_test_feature_d.to(torch::kInt32));
auto feature_train = X_train.index({ feature, "..." });
auto feature_test = X_test.index({ feature, "..." });
auto feature_train_disc = discretizer->fit_transform_t(feature_train, y_train);
auto feature_test_disc = discretizer->transform_t(feature_test);
X_train_d.index_put_({ feature, "..." }, feature_train_disc);
X_test_d.index_put_({ feature, "..." }, feature_test_disc);
} else {
X_train_d.index_put_({ feature, "..." }, X_train.index({ feature, "..." }).to(torch::kInt32));
X_test_d.index_put_({ feature, "..." }, X_test.index({ feature, "..." }).to(torch::kInt32));
@@ -289,7 +268,12 @@ namespace platform {
}
X_train = X_train_d;
X_test = X_test_d;
assert(X_train.dtype() == torch::kInt32);
assert(X_test.dtype() == torch::kInt32);
computeStates();
}
assert(y_train.dtype() == torch::kInt32);
assert(y_test.dtype() == torch::kInt32);
return { X_train, X_test, y_train, y_test };
}
}