Fix XSpode
This commit is contained in:
@@ -45,7 +45,7 @@ namespace bayesnet {
|
|||||||
n = X.size();
|
n = X.size();
|
||||||
buildModel(weights_);
|
buildModel(weights_);
|
||||||
trainModel(weights_, smoothing);
|
trainModel(weights_, smoothing);
|
||||||
fitted=true;
|
fitted = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// --------------------------------------
|
// --------------------------------------
|
||||||
@@ -264,7 +264,7 @@ namespace bayesnet {
|
|||||||
normalize(probs);
|
normalize(probs);
|
||||||
return probs;
|
return probs;
|
||||||
}
|
}
|
||||||
std::vector<std::vector<double>> XSpode::predict_proba(std::vector<std::vector<int>>& test_data)
|
std::vector<std::vector<double>> XSpode::predict_proba(std::vector<std::vector<int>>& test_data)
|
||||||
{
|
{
|
||||||
int test_size = test_data[0].size();
|
int test_size = test_data[0].size();
|
||||||
int sample_size = test_data.size();
|
int sample_size = test_data.size();
|
||||||
@@ -397,22 +397,15 @@ namespace bayesnet {
|
|||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
torch::Tensor XSpode::predict_proba(torch::Tensor& X)
|
torch::Tensor XSpode::predict_proba(torch::Tensor& X)
|
||||||
{
|
|
||||||
auto X_ = TensorUtils::to_matrix(X);
|
|
||||||
auto result_v = predict_proba(X_);
|
|
||||||
torch::Tensor result;
|
|
||||||
for (int i = 0; i < result_v.size(); ++i) {
|
|
||||||
result.index_put_({ i, "..." }, torch::tensor(result_v[i], torch::kDouble));
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
torch::Tensor XSpode::predict(torch::Tensor& X)
|
|
||||||
{
|
{
|
||||||
auto X_ = TensorUtils::to_matrix(X);
|
auto X_ = TensorUtils::to_matrix(X);
|
||||||
auto predict = predict(X_);
|
auto result_v = predict_proba(X_);
|
||||||
return TensorUtils::to_tensor(predict);
|
torch::Tensor result;
|
||||||
|
for (int i = 0; i < result_v.size(); ++i) {
|
||||||
|
result.index_put_({ i, "..." }, torch::tensor(result_v[i], torch::kDouble));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -28,7 +28,7 @@ namespace bayesnet {
|
|||||||
int getNumberOfStates() const override;
|
int getNumberOfStates() const override;
|
||||||
int getClassNumStates() const override;
|
int getClassNumStates() const override;
|
||||||
std::vector<int>& getStates();
|
std::vector<int>& getStates();
|
||||||
std::vector<std::string> graph(const std::string& title) const override { return std::vector<std::string>({title}); }
|
std::vector<std::string> graph(const std::string& title) const override { return std::vector<std::string>({ title }); }
|
||||||
void fit(std::vector<std::vector<int>>& X, std::vector<int>& y, torch::Tensor& weights_, const Smoothing_t smoothing);
|
void fit(std::vector<std::vector<int>>& X, std::vector<int>& y, torch::Tensor& weights_, const Smoothing_t smoothing);
|
||||||
void setHyperparameters(const nlohmann::json& hyperparameters_) override;
|
void setHyperparameters(const nlohmann::json& hyperparameters_) override;
|
||||||
|
|
||||||
@@ -38,7 +38,6 @@ namespace bayesnet {
|
|||||||
torch::Tensor predict(torch::Tensor& X) override;
|
torch::Tensor predict(torch::Tensor& X) override;
|
||||||
std::vector<int> predict(std::vector<std::vector<int>>& X) override;
|
std::vector<int> predict(std::vector<std::vector<int>>& X) override;
|
||||||
torch::Tensor predict_proba(torch::Tensor& X) override;
|
torch::Tensor predict_proba(torch::Tensor& X) override;
|
||||||
std::vector<std::vector<double>> predict_proba(std::vector<std::vector<int>>& X) override;
|
|
||||||
protected:
|
protected:
|
||||||
void buildModel(const torch::Tensor& weights) override;
|
void buildModel(const torch::Tensor& weights) override;
|
||||||
void trainModel(const torch::Tensor& weights, const bayesnet::Smoothing_t smoothing) override;
|
void trainModel(const torch::Tensor& weights, const bayesnet::Smoothing_t smoothing) override;
|
||||||
|
1
lib/catch2
Submodule
1
lib/catch2
Submodule
Submodule lib/catch2 added at 029fe3b460
Submodule lib/folding updated: 9652853d69...2ac43e32ac
Submodule tests/lib/catch2 updated: 0321d2fce3...506276c592
Reference in New Issue
Block a user