Complete nxm
This commit is contained in:
@@ -95,6 +95,7 @@ int main(int argc, char** argv)
|
||||
}
|
||||
);
|
||||
program.add_argument("--discretize").help("Discretize input dataset").default_value(false).implicit_value(true);
|
||||
program.add_argument("--dumpcpt").help("Dump CPT Tables").default_value(false).implicit_value(true);
|
||||
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value(false).implicit_value(true);
|
||||
program.add_argument("--tensors").help("Use tensors to store samples").default_value(false).implicit_value(true);
|
||||
program.add_argument("-f", "--folds").help("Number of folds").default_value(5).scan<'i', int>().action([](const string& value) {
|
||||
@@ -112,7 +113,7 @@ int main(int argc, char** argv)
|
||||
throw runtime_error("Number of folds must be an integer");
|
||||
}});
|
||||
program.add_argument("-s", "--seed").help("Random seed").default_value(-1).scan<'i', int>();
|
||||
bool class_last, stratified, tensors;
|
||||
bool class_last, stratified, tensors, dump_cpt;
|
||||
string model_name, file_name, path, complete_file_name;
|
||||
int nFolds, seed;
|
||||
try {
|
||||
@@ -125,6 +126,7 @@ int main(int argc, char** argv)
|
||||
tensors = program.get<bool>("tensors");
|
||||
nFolds = program.get<int>("folds");
|
||||
seed = program.get<int>("seed");
|
||||
dump_cpt = program.get<bool>("dumpcpt");
|
||||
class_last = datasets[file_name];
|
||||
if (!file_exists(complete_file_name)) {
|
||||
throw runtime_error("Data File " + path + file_name + ".arff" + " does not exist");
|
||||
@@ -158,21 +160,25 @@ int main(int argc, char** argv)
|
||||
states[feature] = vector<int>(maxes[feature]);
|
||||
}
|
||||
states[className] = vector<int>(maxes[className]);
|
||||
|
||||
auto clf = platform::Models::instance()->create(model_name);
|
||||
clf->fit(Xd, y, features, className, states);
|
||||
auto score = clf->score(Xd, y);
|
||||
if (dump_cpt) {
|
||||
cout << "--- CPT Tables ---" << endl;
|
||||
clf->dump_cpt();
|
||||
}
|
||||
auto lines = clf->show();
|
||||
auto graph = clf->graph();
|
||||
for (auto line : lines) {
|
||||
cout << line << endl;
|
||||
}
|
||||
cout << "--- Topological Order ---" << endl;
|
||||
for (auto name : clf->topological_order()) {
|
||||
auto order = clf->topological_order();
|
||||
for (auto name : order) {
|
||||
cout << name << ", ";
|
||||
}
|
||||
cout << "end." << endl;
|
||||
auto score = clf->score(Xd, y);
|
||||
cout << "Score: " << score << endl;
|
||||
auto graph = clf->graph();
|
||||
auto dot_file = model_name + "_" + file_name;
|
||||
ofstream file(dot_file + ".dot");
|
||||
file << graph;
|
||||
@@ -211,9 +217,14 @@ int main(int argc, char** argv)
|
||||
auto [Xtrain, ytrain] = extract_indices(train, Xd, y);
|
||||
auto [Xtest, ytest] = extract_indices(test, Xd, y);
|
||||
clf->fit(Xtrain, ytrain, features, className, states);
|
||||
|
||||
score_train = clf->score(Xtrain, ytrain);
|
||||
score_test = clf->score(Xtest, ytest);
|
||||
}
|
||||
if (dump_cpt) {
|
||||
cout << "--- CPT Tables ---" << endl;
|
||||
clf->dump_cpt();
|
||||
}
|
||||
total_score_train += score_train;
|
||||
total_score += score_test;
|
||||
cout << "Score Train: " << score_train << endl;
|
||||
|
Reference in New Issue
Block a user