tests compile but not link
Some checks failed
CI/CD Pipeline / Code Linting (push) Failing after 21s
CI/CD Pipeline / Build and Test (Debug, clang, ubuntu-latest) (push) Failing after 5m17s
CI/CD Pipeline / Build and Test (Debug, gcc, ubuntu-latest) (push) Failing after 5m58s
CI/CD Pipeline / Build and Test (Release, clang, ubuntu-20.04) (push) Failing after 6m4s
CI/CD Pipeline / Build and Test (Release, clang, ubuntu-latest) (push) Failing after 5m16s
CI/CD Pipeline / Build and Test (Release, gcc, ubuntu-20.04) (push) Failing after 5m51s
CI/CD Pipeline / Build and Test (Release, gcc, ubuntu-latest) (push) Failing after 6m0s
CI/CD Pipeline / Docker Build Test (push) Failing after 1m0s
CI/CD Pipeline / Performance Benchmarks (push) Has been skipped
CI/CD Pipeline / Build Documentation (push) Failing after 25s
CI/CD Pipeline / Create Release Package (push) Has been skipped
Some checks failed
CI/CD Pipeline / Code Linting (push) Failing after 21s
CI/CD Pipeline / Build and Test (Debug, clang, ubuntu-latest) (push) Failing after 5m17s
CI/CD Pipeline / Build and Test (Debug, gcc, ubuntu-latest) (push) Failing after 5m58s
CI/CD Pipeline / Build and Test (Release, clang, ubuntu-20.04) (push) Failing after 6m4s
CI/CD Pipeline / Build and Test (Release, clang, ubuntu-latest) (push) Failing after 5m16s
CI/CD Pipeline / Build and Test (Release, gcc, ubuntu-20.04) (push) Failing after 5m51s
CI/CD Pipeline / Build and Test (Release, gcc, ubuntu-latest) (push) Failing after 6m0s
CI/CD Pipeline / Docker Build Test (push) Failing after 1m0s
CI/CD Pipeline / Performance Benchmarks (push) Has been skipped
CI/CD Pipeline / Build Documentation (push) Failing after 25s
CI/CD Pipeline / Create Release Package (push) Has been skipped
This commit is contained in:
@@ -16,9 +16,10 @@ using namespace svm_classifier;
|
||||
*/
|
||||
std::pair<torch::Tensor, torch::Tensor> generate_multiclass_data(int n_samples = 60,
|
||||
int n_features = 2,
|
||||
int n_classes = 3)
|
||||
int n_classes = 3,
|
||||
int seed = 42)
|
||||
{
|
||||
torch::manual_seed(42);
|
||||
torch::manual_seed(seed);
|
||||
|
||||
auto X = torch::randn({ n_samples, n_features });
|
||||
auto y = torch::randint(0, n_classes, { n_samples });
|
||||
@@ -123,7 +124,7 @@ TEST_CASE("OneVsRestStrategy Prediction", "[unit][multiclass_strategy]")
|
||||
{
|
||||
auto predictions = strategy.predict(X_test, converter);
|
||||
|
||||
REQUIRE(predictions.size() == X_test.size(0));
|
||||
REQUIRE(static_cast<int64_t>(predictions.size()) == X_test.size(0));
|
||||
|
||||
// Check that all predictions are valid class labels
|
||||
auto classes = strategy.get_classes();
|
||||
@@ -136,8 +137,8 @@ TEST_CASE("OneVsRestStrategy Prediction", "[unit][multiclass_strategy]")
|
||||
{
|
||||
auto decision_values = strategy.decision_function(X_test, converter);
|
||||
|
||||
REQUIRE(decision_values.size() == X_test.size(0));
|
||||
REQUIRE(decision_values[0].size() == strategy.get_n_classes());
|
||||
REQUIRE(static_cast<int64_t>(decision_values.size()) == X_test.size(0));
|
||||
REQUIRE(static_cast<int>(decision_values[0].size()) == strategy.get_n_classes());
|
||||
|
||||
// Decision values should be real numbers
|
||||
for (const auto& sample_decisions : decision_values) {
|
||||
@@ -173,7 +174,7 @@ TEST_CASE("OneVsRestStrategy Probability Prediction", "[unit][multiclass_strateg
|
||||
if (strategy.supports_probability()) {
|
||||
auto probabilities = strategy.predict_proba(X, converter);
|
||||
|
||||
REQUIRE(probabilities.size() == X.size(0));
|
||||
REQUIRE(static_cast<int64_t>(probabilities.size()) == X.size(0));
|
||||
REQUIRE(probabilities[0].size() == 3); // 3 classes
|
||||
|
||||
// Check probability constraints
|
||||
@@ -269,7 +270,7 @@ TEST_CASE("OneVsOneStrategy Prediction", "[unit][multiclass_strategy]")
|
||||
{
|
||||
auto predictions = strategy.predict(X_test, converter);
|
||||
|
||||
REQUIRE(predictions.size() == X_test.size(0));
|
||||
REQUIRE(static_cast<int64_t>(predictions.size()) == X_test.size(0));
|
||||
|
||||
auto classes = strategy.get_classes();
|
||||
for (int pred : predictions) {
|
||||
@@ -281,7 +282,7 @@ TEST_CASE("OneVsOneStrategy Prediction", "[unit][multiclass_strategy]")
|
||||
{
|
||||
auto decision_values = strategy.decision_function(X_test, converter);
|
||||
|
||||
REQUIRE(decision_values.size() == X_test.size(0));
|
||||
REQUIRE(static_cast<int64_t>(decision_values.size()) == X_test.size(0));
|
||||
|
||||
// For 3 classes, OvO should have C(3,2) = 3 pairwise comparisons
|
||||
REQUIRE(decision_values[0].size() == 3);
|
||||
@@ -298,7 +299,7 @@ TEST_CASE("OneVsOneStrategy Prediction", "[unit][multiclass_strategy]")
|
||||
// OvO probability estimation is more complex
|
||||
auto probabilities = strategy.predict_proba(X_test, converter);
|
||||
|
||||
REQUIRE(probabilities.size() == X_test.size(0));
|
||||
REQUIRE(static_cast<int64_t>(probabilities.size()) == X_test.size(0));
|
||||
REQUIRE(probabilities[0].size() == 3); // 3 classes
|
||||
|
||||
// Check basic probability constraints
|
||||
@@ -395,7 +396,7 @@ TEST_CASE("MulticlassStrategy Edge Cases", "[unit][multiclass_strategy]")
|
||||
// Implementation might extend to binary case
|
||||
|
||||
auto predictions = strategy.predict(X, converter);
|
||||
REQUIRE(predictions.size() == X.size(0));
|
||||
REQUIRE(static_cast<int64_t>(predictions.size()) == X.size(0));
|
||||
}
|
||||
|
||||
SECTION("Very small dataset")
|
||||
@@ -432,7 +433,7 @@ TEST_CASE("MulticlassStrategy Edge Cases", "[unit][multiclass_strategy]")
|
||||
REQUIRE(strategy.get_n_classes() == 2);
|
||||
|
||||
auto predictions = strategy.predict(X, converter);
|
||||
REQUIRE(predictions.size() == X.size(0));
|
||||
REQUIRE(static_cast<int64_t>(predictions.size()) == X.size(0));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -510,7 +511,7 @@ TEST_CASE("MulticlassStrategy Memory Management", "[unit][multiclass_strategy]")
|
||||
REQUIRE(metrics.status == TrainingStatus::SUCCESS);
|
||||
|
||||
auto predictions = strategy.predict(X, converter);
|
||||
REQUIRE(predictions.size() == X.size(0));
|
||||
REQUIRE(static_cast<int64_t>(predictions.size()) == X.size(0));
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user