refactor fit parameters
This commit is contained in:
@@ -35,7 +35,7 @@ namespace pywrap {
|
||||
{
|
||||
return pyWrap->callMethodString(id, method);
|
||||
}
|
||||
PyClassifier& PyClassifier::fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states)
|
||||
PyClassifier& PyClassifier::fit(torch::Tensor& X, torch::Tensor& y)
|
||||
{
|
||||
if (!fitted && hyperparameters.size() > 0) {
|
||||
pyWrap->setHyperparameters(id, hyperparameters);
|
||||
@@ -47,6 +47,10 @@ namespace pywrap {
|
||||
fitted = true;
|
||||
return *this;
|
||||
}
|
||||
PyClassifier& PyClassifier::fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states)
|
||||
{
|
||||
return fit(X, y);
|
||||
}
|
||||
torch::Tensor PyClassifier::predict(torch::Tensor& X)
|
||||
{
|
||||
int dimension = X.size(1);
|
||||
|
@@ -18,6 +18,7 @@ namespace pywrap {
|
||||
PyClassifier(const std::string& module, const std::string& className);
|
||||
virtual ~PyClassifier();
|
||||
PyClassifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states);
|
||||
PyClassifier& fit(torch::Tensor& X, torch::Tensor& y);
|
||||
torch::Tensor predict(torch::Tensor& X);
|
||||
double score(torch::Tensor& X, torch::Tensor& y);
|
||||
std::string version();
|
||||
|
Reference in New Issue
Block a user