Refactor cross_validation

This commit is contained in:
2023-07-29 16:44:07 +02:00
parent b9e76becce
commit adc0ca238f
3 changed files with 55 additions and 49 deletions

View File

@@ -105,9 +105,9 @@ namespace platform {
Experiment& setDuration(float duration) { this->duration = duration; return *this; }
string get_file_name();
void save(string path);
Result cross_validation(const string& path, const string& fileName);
//Result cross_validation(const string& path, const string& fileName);
Result cross_validation(string model_name, torch::Tensor& X, torch::Tensor& y, vector<string> features, string className, map<string, vector<int>> states);
void show();
};
Result cross_validation(Fold* fold, string model_name, torch::Tensor& X, torch::Tensor& y, vector<string> features, string className, map<string, vector<int>> states);
}
#endif