30 lines
934 B
C++
30 lines
934 B
C++
#ifndef PYCLASSIFER_H
|
|
#define PYCLASSIFER_H
|
|
#include <string>
|
|
#include <map>
|
|
#include <vector>
|
|
#include <utility>
|
|
#include <torch/torch.h>
|
|
#include <boost/python/numpy.hpp>
|
|
#include "PyWrap.h"
|
|
|
|
namespace pywrap {
|
|
|
|
class PyClassifier {
|
|
public:
|
|
PyClassifier(const std::string& module, const std::string& className);
|
|
virtual ~PyClassifier();
|
|
PyClassifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states);
|
|
torch::Tensor predict(torch::Tensor& X);
|
|
double score(torch::Tensor& X, torch::Tensor& y);
|
|
std::string version();
|
|
std::string graph();
|
|
std::string callMethodString(const std::string& method);
|
|
private:
|
|
PyWrap* pyWrap;
|
|
std::string module;
|
|
std::string className;
|
|
};
|
|
|
|
} /* namespace pywrap */
|
|
#endif /* PYCLASSIFER_H */ |