Fix adult numeric features mistake
This commit is contained in:
@@ -531,6 +531,11 @@ TEST_CASE("Test Dataset Loading", "[Datasets]")
|
|||||||
}
|
}
|
||||||
std::cout << "| " << dataset.yt[sample].item<int>() << std::endl;
|
std::cout << "| " << dataset.yt[sample].item<int>() << std::endl;
|
||||||
}
|
}
|
||||||
|
auto features = dataset.features;
|
||||||
|
std::cout << "States:" << std::endl;
|
||||||
|
for (int i = 0; i < 14; i++) {
|
||||||
|
std::cout << i << " has " << dataset.states.at(features[i]).size() << " states." << std::endl;
|
||||||
|
}
|
||||||
dataset = RawDatasets("adult", false);
|
dataset = RawDatasets("adult", false);
|
||||||
std::cout << "Dataset adult raw " << std::endl;
|
std::cout << "Dataset adult raw " << std::endl;
|
||||||
for (int sample = 0; sample < max_sample; sample++) {
|
for (int sample = 0; sample < max_sample; sample++) {
|
||||||
@@ -539,4 +544,10 @@ TEST_CASE("Test Dataset Loading", "[Datasets]")
|
|||||||
}
|
}
|
||||||
std::cout << "| " << dataset.yt[sample].item<int>() << std::endl;
|
std::cout << "| " << dataset.yt[sample].item<int>() << std::endl;
|
||||||
}
|
}
|
||||||
|
std::cout << "States:" << std::endl;
|
||||||
|
for (int i = 0; i < 14; i++) {
|
||||||
|
std::cout << i << " has " << dataset.states.at(features[i]).size() << " states." << std::endl;
|
||||||
|
}
|
||||||
|
auto clf = bayesnet::TANLd();
|
||||||
|
clf.fit(dataset.Xt, dataset.yt, dataset.features, dataset.className, dataset.states, dataset.smoothing);
|
||||||
}
|
}
|
||||||
|
@@ -213,6 +213,8 @@ void RawDatasets::loadDataset(const std::string& name, bool class_last)
|
|||||||
if (!is_numeric.at(i)) {
|
if (!is_numeric.at(i)) {
|
||||||
states[features[i]] = std::vector<int>(maxValues[features[i]]);
|
states[features[i]] = std::vector<int>(maxValues[features[i]]);
|
||||||
iota(begin(states.at(features[i])), end(states.at(features[i])), 0);
|
iota(begin(states.at(features[i])), end(states.at(features[i])), 0);
|
||||||
|
} else {
|
||||||
|
states[features[i]] = std::vector<int>();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
yt = torch::tensor(yv, torch::kInt32);
|
yt = torch::tensor(yv, torch::kInt32);
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
adult;class;[0,2,4,11,12,13]
|
adult;class;[0,2,4,10,11,12]
|
||||||
balance-scale;class; all
|
balance-scale;class; all
|
||||||
breast-w;Class; all
|
breast-w;Class; all
|
||||||
diabetes;class; all
|
diabetes;class; all
|
||||||
|
Reference in New Issue
Block a user