refactor fit parameters

This commit is contained in:
2023-11-11 11:19:33 +01:00
parent b6a3a05020
commit a3bf97e501
3 changed files with 58 additions and 40 deletions

View File

@@ -35,7 +35,7 @@ namespace pywrap {
{
return pyWrap->callMethodString(id, method);
}
PyClassifier& PyClassifier::fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states)
PyClassifier& PyClassifier::fit(torch::Tensor& X, torch::Tensor& y)
{
if (!fitted && hyperparameters.size() > 0) {
pyWrap->setHyperparameters(id, hyperparameters);
@@ -47,6 +47,10 @@ namespace pywrap {
fitted = true;
return *this;
}
PyClassifier& PyClassifier::fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states)
{
return fit(X, y);
}
torch::Tensor PyClassifier::predict(torch::Tensor& X)
{
int dimension = X.size(1);

View File

@@ -18,6 +18,7 @@ namespace pywrap {
PyClassifier(const std::string& module, const std::string& className);
virtual ~PyClassifier();
PyClassifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states);
PyClassifier& fit(torch::Tensor& X, torch::Tensor& y);
torch::Tensor predict(torch::Tensor& X);
double score(torch::Tensor& X, torch::Tensor& y);
std::string version();