2024-04-11 16:02:49 +00:00
// ***************************************************************
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************
2024-04-07 22:55:30 +00:00
# include <catch2/catch_test_macros.hpp>
# include <catch2/matchers/catch_matchers.hpp>
# include <string>
# include "TestUtils.h"
# include "bayesnet/classifiers/TAN.h"
2024-04-08 20:30:55 +00:00
# include "bayesnet/classifiers/KDB.h"
# include "bayesnet/classifiers/KDBLd.h"
2024-04-07 22:55:30 +00:00
// fit() on std::vector input whose label vector is one element shorter than the
// number of samples must throw std::runtime_error reporting both dimensions.
TEST_CASE("Test Cannot build dataset with wrong data vector", "[Classifier]")
{
    auto model = bayesnet::TAN();
    auto raw = RawDatasets("iris", true);
    raw.yv.pop_back(); // drop one label so y (149) no longer matches X (150 samples)
    // Both assertions below fit the model again with identical arguments.
    const auto fit_model = [&] {
        model.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing);
    };
    REQUIRE_THROWS_AS(fit_model(), std::runtime_error);
    REQUIRE_THROWS_WITH(fit_model(), "* Error in X and y dimensions *\nX dimensions: [4, 150]\ny dimensions: [149]");
}
// fit() on tensor input whose label tensor has 149 entries while X has 150 samples
// must throw std::runtime_error reporting both dimensions.
TEST_CASE("Test Cannot build dataset with wrong data tensor", "[Classifier]")
{
    auto model = bayesnet::TAN();
    auto raw = RawDatasets("iris", true);
    // Kept non-const: fit() receives it as an lvalue tensor.
    auto yshort = torch::zeros({ 149 }, torch::kInt32);
    const auto fit_model = [&] {
        model.fit(raw.Xt, yshort, raw.features, raw.className, raw.states, raw.smoothing);
    };
    REQUIRE_THROWS_AS(fit_model(), std::runtime_error);
    REQUIRE_THROWS_WITH(fit_model(), "* Error in X and y dimensions *\nX dimensions: [4, 150]\ny dimensions: [149]");
}
// Non-discretized data (RawDatasets(..., false)) must be rejected by fit() with
// std::invalid_argument, because the classifier requires integer-typed X and y.
TEST_CASE("Invalid data type", "[Classifier]")
{
    auto model = bayesnet::TAN();
    auto raw = RawDatasets("iris", false);
    const auto fit_model = [&] {
        model.fit(raw.Xt, raw.yt, raw.features, raw.className, raw.states, raw.smoothing);
    };
    REQUIRE_THROWS_AS(fit_model(), std::invalid_argument);
    REQUIRE_THROWS_WITH(fit_model(), "dataset (X, y) must be of type Integer");
}
// X with one more feature row than the feature-name list (5 vs 4) must make fit()
// throw std::invalid_argument naming both counts.
TEST_CASE("Invalid number of features", "[Classifier]")
{
    auto model = bayesnet::TAN();
    auto raw = RawDatasets("iris", true);
    // raw.Xt is laid out features x samples (rows are concatenated along dim 0 and the
    // error reports 5 features). Derive the extra row's width and dtype from raw.Xt instead
    // of hard-coding 150 samples / kInt32, so the only mismatch is the feature count.
    auto Xt = torch::cat({ raw.Xt, torch::zeros({ 1, raw.Xt.size(1) }, raw.Xt.options()) }, 0);
    REQUIRE_THROWS_AS(model.fit(Xt, raw.yt, raw.features, raw.className, raw.states, raw.smoothing), std::invalid_argument);
    REQUIRE_THROWS_WITH(model.fit(Xt, raw.yt, raw.features, raw.className, raw.states, raw.smoothing), "Classifier: X 5 and features 4 must have the same number of features");
}
// A class name that has no entry in the states map ("duck") must make fit() throw
// std::invalid_argument.
TEST_CASE("Invalid class name", "[Classifier]")
{
    auto model = bayesnet::TAN();
    auto raw = RawDatasets("iris", true);
    const auto fit_with_unknown_class = [&] {
        model.fit(raw.Xt, raw.yt, raw.features, "duck", raw.states, raw.smoothing);
    };
    REQUIRE_THROWS_AS(fit_with_unknown_class(), std::invalid_argument);
    REQUIRE_THROWS_WITH(fit_with_unknown_class(), "class name not found in states");
}
// If a feature listed in raw.features is missing from the states map, fit() must
// throw std::invalid_argument naming that feature.
TEST_CASE("Invalid feature name", "[Classifier]")
{
    auto model = bayesnet::TAN();
    auto raw = RawDatasets("iris", true);
    // Work on a copy so raw.states stays intact; non-const because it is edited and then
    // handed to fit() as an lvalue.
    auto states_without_petallength = raw.states;
    states_without_petallength.erase("petallength");
    const auto fit_model = [&] {
        model.fit(raw.Xt, raw.yt, raw.features, raw.className, states_without_petallength, raw.smoothing);
    };
    REQUIRE_THROWS_AS(fit_model(), std::invalid_argument);
    REQUIRE_THROWS_WITH(fit_model(), "feature [petallength] not found in states");
}
2024-04-11 15:29:46 +00:00
// KDB::setHyperparameters() must reject {"alpha": "0.0"} with std::invalid_argument whose
// message echoes the offending JSON. The model is never fitted here, so no dataset is
// loaded (the previous, unused RawDatasets("iris", true) local only cost file I/O).
// NOTE(review): presumably "alpha" is not a KDB hyperparameter — confirm against KDB.
TEST_CASE("Invalid hyperparameter", "[Classifier]")
{
    auto model = bayesnet::KDB(2);
    REQUIRE_THROWS_AS(model.setHyperparameters({ { "alpha", "0.0" } }), std::invalid_argument);
    REQUIRE_THROWS_WITH(model.setHyperparameters({ { "alpha", "0.0" } }), "Invalid hyperparameters{\"alpha\":\"0.0\"}");
}
2024-04-07 23:25:14 +00:00
// After fitting TAN on discretized iris, topological_order() must list exactly the four
// features in the pinned order below.
TEST_CASE("Topological order", "[Classifier]")
{
    auto model = bayesnet::TAN();
    auto raw = RawDatasets("iris", true);
    model.fit(raw.Xt, raw.yt, raw.features, raw.className, raw.states, raw.smoothing);
    const auto order = model.topological_order();
    const std::vector<std::string> expected{ "petallength", "sepallength", "sepalwidth", "petalwidth" };
    REQUIRE(order.size() == 4);
    for (std::size_t position = 0; position < expected.size(); ++position) {
        REQUIRE(order[position] == expected[position]);
    }
}
2024-04-11 16:16:06 +00:00
// dump_cpt() on a TAN fitted to discretized iris must produce a dump of exactly 1718
// characters.
// NOTE(review): this pins the textual format of the dump; any formatting change to
// dump_cpt() will require updating the constant.
TEST_CASE("Dump_cpt", "[Classifier]")
{
    auto raw = RawDatasets("iris", true);
    auto model = bayesnet::TAN();
    model.fit(raw.Xt, raw.yt, raw.features, raw.className, raw.states, raw.smoothing);
    REQUIRE(model.dump_cpt().size() == 1718);
}
2024-04-07 23:25:14 +00:00
// Every inference entry point (predict, predict_proba, score), for both tensor and
// std::vector inputs, must throw std::logic_error before fit() has been called.
TEST_CASE("Not fitted model", "[Classifier]")
{
    auto model = bayesnet::TAN();
    auto raw = RawDatasets("iris", true);
    const std::string message{ "Classifier has not been fitted" };

    // torch::Tensor overloads
    REQUIRE_THROWS_AS(model.predict(raw.Xt), std::logic_error);
    REQUIRE_THROWS_WITH(model.predict(raw.Xt), message);
    REQUIRE_THROWS_AS(model.predict_proba(raw.Xt), std::logic_error);
    REQUIRE_THROWS_WITH(model.predict_proba(raw.Xt), message);
    REQUIRE_THROWS_AS(model.score(raw.Xt, raw.yt), std::logic_error);
    REQUIRE_THROWS_WITH(model.score(raw.Xt, raw.yt), message);

    // std::vector overloads
    REQUIRE_THROWS_AS(model.predict(raw.Xv), std::logic_error);
    REQUIRE_THROWS_WITH(model.predict(raw.Xv), message);
    REQUIRE_THROWS_AS(model.predict_proba(raw.Xv), std::logic_error);
    REQUIRE_THROWS_WITH(model.predict_proba(raw.Xv), message);
    REQUIRE_THROWS_AS(model.score(raw.Xv, raw.yv), std::logic_error);
    REQUIRE_THROWS_WITH(model.score(raw.Xv, raw.yv), message);
}
// A KDB(k=2) fitted on discretized iris (std::vector input) must produce a graph
// description with 15 entries.
TEST_CASE("KDB Graph", "[Classifier]")
{
    auto raw = RawDatasets("iris", true);
    auto model = bayesnet::KDB(2);
    model.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing);
    REQUIRE(model.graph().size() == 15);
}
// A KDBLd(k=2) fitted on non-discretized iris tensors (RawDatasets(..., false)) must
// produce a graph description with 15 entries.
// NOTE(review): presumably KDBLd discretizes the continuous input itself — confirm.
TEST_CASE("KDBLd Graph", "[Classifier]")
{
    auto raw = RawDatasets("iris", false);
    auto model = bayesnet::KDBLd(2);
    model.fit(raw.Xt, raw.yt, raw.features, raw.className, raw.states, raw.smoothing);
    REQUIRE(model.graph().size() == 15);
}