19 Classifier& fit(std::vector<std::vector<int>>& X, std::vector<int>& y,
const std::vector<std::string>& features,
const std::string& className, std::map<std::string, std::vector<int>>& states)
override;
20 Classifier& fit(torch::Tensor& X, torch::Tensor& y,
const std::vector<std::string>& features,
const std::string& className, std::map<std::string, std::vector<int>>& states)
override;
21 Classifier& fit(torch::Tensor& dataset,
const std::vector<std::string>& features,
const std::string& className, std::map<std::string, std::vector<int>>& states)
override;
22 Classifier& fit(torch::Tensor& dataset,
const std::vector<std::string>& features,
const std::string& className, std::map<std::string, std::vector<int>>& states,
const torch::Tensor& weights)
override;
24 int getNumberOfNodes()
const override;
25 int getNumberOfEdges()
const override;
26 int getNumberOfStates()
const override;
27 int getClassNumStates()
const override;
28 torch::Tensor predict(torch::Tensor& X)
override;
29 std::vector<int> predict(std::vector<std::vector<int>>& X)
override;
30 torch::Tensor predict_proba(torch::Tensor& X)
override;
31 std::vector<std::vector<double>> predict_proba(std::vector<std::vector<int>>& X)
override;
32 status_t getStatus()
const override {
return status; }
33 std::string getVersion()
override {
return { project_version.begin(), project_version.end() }; };
34 float score(torch::Tensor& X, torch::Tensor& y)
override;
35 float score(std::vector<std::vector<int>>& X, std::vector<int>& y)
override;
36 std::vector<std::string> show()
const override;
37 std::vector<std::string> topological_order()
override;
38 std::vector<std::string> getNotes()
const override {
return notes; }
39 std::string dump_cpt()
const override;
40 void setHyperparameters(
const nlohmann::json& hyperparameters)
override;
46 std::vector<std::string> features;
47 std::string className;
48 std::map<std::string, std::vector<int>> states;
49 torch::Tensor dataset;
50 status_t status = NORMAL;
51 std::vector<std::string> notes;
/// Validates the parameters used for fitting. NOTE(review): what is checked, and how failures are reported, is not visible here.
void checkFitParameters();
53 virtual void buildModel(
const torch::Tensor& weights) = 0;
54 void trainModel(
const torch::Tensor& weights)
override;
55 void buildDataset(torch::Tensor& y);
57 Classifier& build(
const std::vector<std::string>& features,
const std::string& className, std::map<std::string, std::vector<int>>& states,
const torch::Tensor& weights);