Add changeModel to b_manage

This commit is contained in:
2025-02-04 17:34:00 +01:00
parent cbe8f4c79c
commit 73a4b3d5e5
4 changed files with 37 additions and 6 deletions

View File

@@ -90,7 +90,7 @@ cmake_path(SET TEST_DATA_PATH "${CMAKE_CURRENT_SOURCE_DIR}/tests/data")
configure_file(src/common/SourceData.h.in "${CMAKE_BINARY_DIR}/configured_files/include/SourceData.h") configure_file(src/common/SourceData.h.in "${CMAKE_BINARY_DIR}/configured_files/include/SourceData.h")
add_subdirectory(config) add_subdirectory(config)
add_subdirectory(src) add_subdirectory(src)
# add_subdirectory(sample) add_subdirectory(sample)
file(GLOB Platform_SOURCES CONFIGURE_DEPENDS ${Platform_SOURCE_DIR}/src/*.cpp) file(GLOB Platform_SOURCES CONFIGURE_DEPENDS ${Platform_SOURCE_DIR}/src/*.cpp)
# Testing # Testing

View File

@@ -219,14 +219,14 @@ int main(int argc, char** argv)
auto [Xtrain, ytrain] = extract_indices(train, Xd, y); auto [Xtrain, ytrain] = extract_indices(train, Xd, y);
auto [Xtest, ytest] = extract_indices(test, Xd, y); auto [Xtest, ytest] = extract_indices(test, Xd, y);
clf->fit(Xtrain, ytrain, features, className, states, smoothing); clf->fit(Xtrain, ytrain, features, className, states, smoothing);
std::cout << "Nodes: " << clf->getNumberOfNodes() << std::endl; std::cout << "Nodes: " << clf->geb_managetNumberOfNodes() << std::endl;
nodes += clf->getNumberOfNodes(); nodes += clf->getNumberOfNodes();
score_train = clf->score(Xtrain, ytrain); score_train = clf->score(Xtrain, ytrain);
score_test = clf->score(Xtest, ytest); score_test = clf->score(Xtest, ytest);
} }
if (dump_cpt) { if (dump_cpt) {
std::cout << "--- CPT Tables ---" << std::endl; std::cout << "--- CPT Tables ---" << std::endl;
clf->dump_cpt(); std::cout << clf->dump_cpt();
} }
total_score_train += score_train; total_score_train += score_train;
total_score += score_test; total_score += score_test;

View File

@@ -312,6 +312,34 @@ namespace platform {
return "Reporting " + results.at(index).getFilename(); return "Reporting " + results.at(index).getFilename();
} }
} }
void ManageScreen::changeModel(const int index)
{
std::cout << "Old model: " << results.at(index).getModel() << std::endl;
std::cout << "New model: ";
std::string newModel;
getline(std::cin, newModel);
if (newModel.empty()) {
list("Model not changed", Colors::YELLOW());
return;
}
if (newModel == results.at(index).getModel()) {
list("Model already set to " + newModel, Colors::RED());
return;
}
// Remove the old result file
std::string oldFile = Paths::results() + results.at(index).getFilename();
std::filesystem::remove(oldFile);
// Actually change the model
results.at(index).setModel(newModel);
results.at(index).save();
int newModelSize = static_cast<int>(newModel.size());
if (newModelSize > maxModel) {
maxModel = newModelSize;
header_lengths[2] = maxModel;
updateSize(rows, cols);
}
list("Model changed to " + newModel, Colors::GREEN());
}
std::pair<std::string, std::string> ManageScreen::sortList() std::pair<std::string, std::string> ManageScreen::sortList()
{ {
std::vector<std::tuple<std::string, char, bool>> sortOptions = { std::vector<std::tuple<std::string, char, bool>> sortOptions = {
@@ -372,6 +400,7 @@ namespace platform {
{"list", 'l', false}, {"list", 'l', false},
{"Delete", 'D', true}, {"Delete", 'D', true},
{"datasets", 'd', false}, {"datasets", 'd', false},
{"change model", 'm', true},
{"hide", 'h', true}, {"hide", 'h', true},
{"sort", 's', false}, {"sort", 's', false},
{"report", 'r', true}, {"report", 'r', true},
@@ -498,6 +527,9 @@ namespace platform {
paginator[static_cast<int>(OutputType::EXPERIMENTS)].setTotal(results.size()); paginator[static_cast<int>(OutputType::EXPERIMENTS)].setTotal(results.size());
list(filename + " deleted!", Colors::RED()); list(filename + " deleted!", Colors::RED());
break; break;
case 'm':
changeModel(index);
break;
case 'h': case 'h':
{ {
std::string status_message; std::string status_message;
@@ -544,7 +576,6 @@ namespace platform {
break; break;
case 't': case 't':
{ {
std::string status_message;
std::cout << "Title: " << results.at(index).getTitle() << std::endl; std::cout << "Title: " << results.at(index).getTitle() << std::endl;
std::cout << "New title: "; std::cout << "New title: ";
std::string newTitle; std::string newTitle;
@@ -552,8 +583,7 @@ namespace platform {
if (!newTitle.empty()) { if (!newTitle.empty()) {
results.at(index).setTitle(newTitle); results.at(index).setTitle(newTitle);
results.at(index).save(); results.at(index).save();
status_message = "Title changed to " + newTitle; list("Title changed to " + newTitle, Colors::GREEN());
list(status_message, Colors::GREEN());
break; break;
} }
list("No title change!", Colors::YELLOW()); list("No title change!", Colors::YELLOW());

View File

@@ -27,6 +27,7 @@ namespace platform {
void list_datasets(const std::string& status, const std::string& color); void list_datasets(const std::string& status, const std::string& color);
bool confirmAction(const std::string& intent, const std::string& fileName) const; bool confirmAction(const std::string& intent, const std::string& fileName) const;
std::string report(const int index, const bool excelReport); std::string report(const int index, const bool excelReport);
void changeModel(const int index);
std::string report_compared(); std::string report_compared();
std::pair<std::string, std::string> sortList(); std::pair<std::string, std::string> sortList();
std::string getVersions(); std::string getVersions();