From 235367da01d91f22db2ee1edf860e1cecd9a0b6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Mon, 23 Jun 2025 12:19:22 +0200 Subject: [PATCH] tests compile but not link --- external/CMakeLists.txt | 28 +++++++++++++++++----------- tests/test_multiclass_strategy.cpp | 25 +++++++++++++------------ 2 files changed, 30 insertions(+), 23 deletions(-) diff --git a/external/CMakeLists.txt b/external/CMakeLists.txt index 4446312..41430c3 100644 --- a/external/CMakeLists.txt +++ b/external/CMakeLists.txt @@ -64,19 +64,25 @@ elseif(EXISTS "${liblinear_SOURCE_DIR}/newton.cpp") list(APPEND LIBLINEAR_SOURCES "${liblinear_SOURCE_DIR}/newton.cpp") endif() -# Check for BLAS files in blas directory -file(GLOB BLAS_C_FILES "${liblinear_SOURCE_DIR}/blas/*.c") -file(GLOB BLAS_CPP_FILES "${liblinear_SOURCE_DIR}/blas/*.cpp") - -if(BLAS_C_FILES OR BLAS_CPP_FILES) - list(APPEND LIBLINEAR_SOURCES ${BLAS_C_FILES} ${BLAS_CPP_FILES}) -else() - # Try alternative BLAS file names - foreach(blas_file daxpy ddot dnrm2 dscal) - if(EXISTS "${liblinear_SOURCE_DIR}/blas/${blas_file}.c") - list(APPEND LIBLINEAR_SOURCES "${liblinear_SOURCE_DIR}/blas/${blas_file}.c") +# Check for BLAS files in blas directory - Fixed to ensure they're included +if(EXISTS "${liblinear_SOURCE_DIR}/blas") + # Explicitly add the required BLAS source files + set(BLAS_FILES + "${liblinear_SOURCE_DIR}/blas/daxpy.c" + "${liblinear_SOURCE_DIR}/blas/ddot.c" + "${liblinear_SOURCE_DIR}/blas/dnrm2.c" + "${liblinear_SOURCE_DIR}/blas/dscal.c" + ) + + # Check which BLAS files actually exist and add them + foreach(blas_file ${BLAS_FILES}) + if(EXISTS ${blas_file}) + list(APPEND LIBLINEAR_SOURCES ${blas_file}) + message(STATUS "Adding BLAS file: ${blas_file}") endif() endforeach() +else() + message(WARNING "BLAS directory not found in liblinear source") endif() # Create liblinear object library if we have source files diff --git a/tests/test_multiclass_strategy.cpp b/tests/test_multiclass_strategy.cpp index 184b966..ed36c86 100644 --- a/tests/test_multiclass_strategy.cpp +++ b/tests/test_multiclass_strategy.cpp @@ -16,9 +16,10 @@ using namespace svm_classifier; */ std::pair generate_multiclass_data(int n_samples = 60, int n_features = 2, - int n_classes = 3) + int n_classes = 3, + int seed = 42) { - torch::manual_seed(42); + torch::manual_seed(seed); auto X = torch::randn({ n_samples, n_features }); auto y = torch::randint(0, n_classes, { n_samples }); @@ -123,7 +124,7 @@ TEST_CASE("OneVsRestStrategy Prediction", "[unit][multiclass_strategy]") { auto predictions = strategy.predict(X_test, converter); - REQUIRE(predictions.size() == X_test.size(0)); + REQUIRE(static_cast(predictions.size()) == X_test.size(0)); // Check that all predictions are valid class labels auto classes = strategy.get_classes(); @@ -136,8 +137,8 @@ TEST_CASE("OneVsRestStrategy Prediction", "[unit][multiclass_strategy]") { auto decision_values = strategy.decision_function(X_test, converter); - REQUIRE(decision_values.size() == X_test.size(0)); - REQUIRE(decision_values[0].size() == strategy.get_n_classes()); + REQUIRE(static_cast(decision_values.size()) == X_test.size(0)); + REQUIRE(static_cast(decision_values[0].size()) == strategy.get_n_classes()); // Decision values should be real numbers for (const auto& sample_decisions : decision_values) { @@ -173,7 +174,7 @@ TEST_CASE("OneVsRestStrategy Probability Prediction", "[unit][multiclass_strateg if (strategy.supports_probability()) { auto probabilities = strategy.predict_proba(X, converter); - REQUIRE(probabilities.size() == X.size(0)); + REQUIRE(static_cast(probabilities.size()) == X.size(0)); REQUIRE(probabilities[0].size() == 3); // 3 classes // Check probability constraints @@ -269,7 +270,7 @@ TEST_CASE("OneVsOneStrategy Prediction", "[unit][multiclass_strategy]") { auto predictions = strategy.predict(X_test, converter); - REQUIRE(predictions.size() == X_test.size(0)); + REQUIRE(static_cast(predictions.size()) == X_test.size(0)); auto classes = strategy.get_classes(); for (int pred : predictions) { @@ -281,7 +282,7 @@ TEST_CASE("OneVsOneStrategy Prediction", "[unit][multiclass_strategy]") { auto decision_values = strategy.decision_function(X_test, converter); - REQUIRE(decision_values.size() == X_test.size(0)); + REQUIRE(static_cast(decision_values.size()) == X_test.size(0)); // For 3 classes, OvO should have C(3,2) = 3 pairwise comparisons REQUIRE(decision_values[0].size() == 3); @@ -298,7 +299,7 @@ TEST_CASE("OneVsOneStrategy Prediction", "[unit][multiclass_strategy]") // OvO probability estimation is more complex auto probabilities = strategy.predict_proba(X_test, converter); - REQUIRE(probabilities.size() == X_test.size(0)); + REQUIRE(static_cast(probabilities.size()) == X_test.size(0)); REQUIRE(probabilities[0].size() == 3); // 3 classes // Check basic probability constraints @@ -395,7 +396,7 @@ TEST_CASE("MulticlassStrategy Edge Cases", "[unit][multiclass_strategy]") // Implementation might extend to binary case auto predictions = strategy.predict(X, converter); - REQUIRE(predictions.size() == X.size(0)); + REQUIRE(static_cast(predictions.size()) == X.size(0)); } SECTION("Very small dataset") @@ -432,7 +433,7 @@ TEST_CASE("MulticlassStrategy Edge Cases", "[unit][multiclass_strategy]") REQUIRE(strategy.get_n_classes() == 2); auto predictions = strategy.predict(X, converter); - REQUIRE(predictions.size() == X.size(0)); + REQUIRE(static_cast(predictions.size()) == X.size(0)); } } @@ -510,7 +511,7 @@ TEST_CASE("MulticlassStrategy Memory Management", "[unit][multiclass_strategy]") REQUIRE(metrics.status == TrainingStatus::SUCCESS); auto predictions = strategy.predict(X, converter); - REQUIRE(predictions.size() == X.size(0)); + REQUIRE(static_cast(predictions.size()) == X.size(0)); } } } \ No newline at end of file