refactor fit parameters

This commit is contained in:
2023-11-11 11:19:33 +01:00
parent b6a3a05020
commit a3bf97e501
3 changed files with 58 additions and 40 deletions

View File

@@ -18,6 +18,7 @@ namespace pywrap {
PyClassifier(const std::string& module, const std::string& className);
virtual ~PyClassifier();
PyClassifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states);
PyClassifier& fit(torch::Tensor& X, torch::Tensor& y);
torch::Tensor predict(torch::Tensor& X);
double score(torch::Tensor& X, torch::Tensor& y);
std::string version();