2024-04-11 16:02:49 +00:00
// ***************************************************************
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************
2024-04-07 22:55:30 +00:00
# include <sstream>
2024-03-08 21:20:54 +00:00
# include "bayesnet/utils/bayesnetUtils.h"
2023-07-22 21:07:56 +00:00
# include "Classifier.h"
2023-07-13 01:15:42 +00:00
namespace bayesnet {
2023-07-22 21:07:56 +00:00
Classifier : : Classifier ( Network model ) : model ( model ) , m ( 0 ) , n ( 0 ) , metrics ( Metrics ( ) ) , fitted ( false ) { }
2024-02-22 10:45:40 +00:00
// Message carried by the std::logic_error thrown from the prediction and
// scoring methods when they are called before the classifier has been fitted.
const std::string CLASSIFIER_NOT_FITTED{ "Classifier has not been fitted" };
2023-11-08 17:45:35 +00:00
// Shared tail of every fit() overload.
// Records the feature/class metadata and validates it against the tensor already
// stored in `dataset` (see checkFitParameters). It then computes the metrics,
// initializes the network, and builds and trains the model using `weights`.
// `dataset` is expected to be (features + 1) x samples; checkFitParameters
// enforces the row count.
Classifier& Classifier::build(const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights)
{
    // checkFitParameters() reads the members, so store the metadata first.
    this->features = features;
    this->className = className;
    this->states = states;
    // Samples are laid out along dimension 1.
    this->m = this->dataset.size(1);
    this->n = features.size();
    checkFitParameters();
    // Safe: checkFitParameters() has verified that className is a key of states.
    auto classCardinality = states.at(className).size();
    this->metrics = Metrics(this->dataset, features, className, classCardinality);
    this->model.initialize();
    buildModel(weights);
    trainModel(weights);
    this->fitted = true;
    return *this;
}
2023-11-08 17:45:35 +00:00
void Classifier : : buildDataset ( torch : : Tensor & ytmp )
2023-08-07 10:49:37 +00:00
{
2023-08-07 11:50:11 +00:00
try {
auto yresized = torch : : transpose ( ytmp . view ( { ytmp . size ( 0 ) , 1 } ) , 0 , 1 ) ;
dataset = torch : : cat ( { dataset , yresized } , 0 ) ;
}
catch ( const std : : exception & e ) {
2024-04-07 22:55:30 +00:00
std : : stringstream oss ;
oss < < " * Error in X and y dimensions * \n " ;
oss < < " X dimensions: " < < dataset . sizes ( ) < < " \n " ;
oss < < " y dimensions: " < < ytmp . sizes ( ) ;
throw std : : runtime_error ( oss . str ( ) ) ;
2023-08-07 11:50:11 +00:00
}
2023-08-07 10:49:37 +00:00
}
2023-08-15 13:04:56 +00:00
void Classifier : : trainModel ( const torch : : Tensor & weights )
2023-08-07 10:49:37 +00:00
{
2023-08-12 22:59:02 +00:00
model . fit ( dataset , weights , features , className , states ) ;
2023-08-07 10:49:37 +00:00
}
2023-08-03 18:22:33 +00:00
// X is nxm where n is the number of features and m the number of samples
2023-11-08 17:45:35 +00:00
Classifier & Classifier : : fit ( torch : : Tensor & X , torch : : Tensor & y , const std : : vector < std : : string > & features , const std : : string & className , std : : map < std : : string , std : : vector < int > > & states )
2023-07-23 12:10:28 +00:00
{
2023-08-07 10:49:37 +00:00
dataset = X ;
buildDataset ( y ) ;
2023-08-16 10:46:09 +00:00
const torch : : Tensor weights = torch : : full ( { dataset . size ( 1 ) } , 1.0 / dataset . size ( 1 ) , torch : : kDouble ) ;
2023-08-15 13:59:56 +00:00
return build ( features , className , states , weights ) ;
2023-07-23 12:10:28 +00:00
}
2023-08-07 10:49:37 +00:00
// X is nxm where n is the number of features and m the number of samples
2023-11-08 17:45:35 +00:00
Classifier & Classifier : : fit ( std : : vector < std : : vector < int > > & X , std : : vector < int > & y , const std : : vector < std : : string > & features , const std : : string & className , std : : map < std : : string , std : : vector < int > > & states )
2023-08-05 12:40:42 +00:00
{
2023-11-08 17:45:35 +00:00
dataset = torch : : zeros ( { static_cast < int > ( X . size ( ) ) , static_cast < int > ( X [ 0 ] . size ( ) ) } , torch : : kInt32 ) ;
2023-08-07 10:49:37 +00:00
for ( int i = 0 ; i < X . size ( ) ; + + i ) {
2023-11-08 17:45:35 +00:00
dataset . index_put_ ( { i , " ... " } , torch : : tensor ( X [ i ] , torch : : kInt32 ) ) ;
2023-08-05 12:40:42 +00:00
}
2023-11-08 17:45:35 +00:00
auto ytmp = torch : : tensor ( y , torch : : kInt32 ) ;
2023-08-07 10:49:37 +00:00
buildDataset ( ytmp ) ;
2023-08-16 10:46:09 +00:00
const torch : : Tensor weights = torch : : full ( { dataset . size ( 1 ) } , 1.0 / dataset . size ( 1 ) , torch : : kDouble ) ;
2023-08-15 13:59:56 +00:00
return build ( features , className , states , weights ) ;
2023-08-05 12:40:42 +00:00
}
2023-11-08 17:45:35 +00:00
// Fits from a single tensor expected to already contain the labels as its last
// row, i.e. (features + 1) x samples. Every sample receives the same weight,
// 1/samples.
Classifier& Classifier::fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states)
{
    this->dataset = dataset;
    const auto samples = this->dataset.size(1);
    const torch::Tensor uniformWeights = torch::full({ samples }, 1.0 / samples, torch::kDouble);
    return build(features, className, states, uniformWeights);
}
2023-11-08 17:45:35 +00:00
// Fits from a single (features + 1) x samples tensor, labels in the last row,
// using caller-supplied sample weights.
Classifier& Classifier::fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights)
{
    // Tensor assignment shares the underlying storage; build() reads the member.
    this->dataset = dataset;
    Classifier& fittedClassifier = build(features, className, states, weights);
    return fittedClassifier;
}
2023-07-22 21:07:56 +00:00
void Classifier : : checkFitParameters ( )
2023-07-13 01:15:42 +00:00
{
2023-08-24 10:09:35 +00:00
if ( torch : : is_floating_point ( dataset ) ) {
2023-11-08 17:45:35 +00:00
throw std : : invalid_argument ( " dataset (X, y) must be of type Integer " ) ;
2023-08-24 10:09:35 +00:00
}
2024-04-07 23:25:14 +00:00
if ( dataset . size ( 0 ) - 1 ! = features . size ( ) ) {
throw std : : invalid_argument ( " Classifier: X " + std : : to_string ( dataset . size ( 0 ) - 1 ) + " and features " + std : : to_string ( features . size ( ) ) + " must have the same number of features " ) ;
2023-07-13 01:15:42 +00:00
}
if ( states . find ( className ) = = states . end ( ) ) {
2024-04-07 23:25:14 +00:00
throw std : : invalid_argument ( " class name not found in states " ) ;
2023-07-13 01:15:42 +00:00
}
for ( auto feature : features ) {
if ( states . find ( feature ) = = states . end ( ) ) {
2023-11-08 17:45:35 +00:00
throw std : : invalid_argument ( " feature [ " + feature + " ] not found in states " ) ;
2023-07-13 01:15:42 +00:00
}
}
}
2023-11-08 17:45:35 +00:00
// Predicts class labels for X by delegating to the underlying network.
// Throws std::logic_error if called before the classifier has been fitted.
torch::Tensor Classifier::predict(torch::Tensor& X)
{
    if (fitted) {
        return model.predict(X);
    }
    throw std::logic_error(CLASSIFIER_NOT_FITTED);
}
2023-11-08 17:45:35 +00:00
std : : vector < int > Classifier : : predict ( std : : vector < std : : vector < int > > & X )
2023-07-14 23:59:30 +00:00
{
if ( ! fitted ) {
2024-02-22 10:45:40 +00:00
throw std : : logic_error ( CLASSIFIER_NOT_FITTED ) ;
2023-07-14 23:59:30 +00:00
}
auto m_ = X [ 0 ] . size ( ) ;
auto n_ = X . size ( ) ;
2023-11-08 17:45:35 +00:00
std : : vector < std : : vector < int > > Xd ( n_ , std : : vector < int > ( m_ , 0 ) ) ;
2023-07-14 23:59:30 +00:00
for ( auto i = 0 ; i < n_ ; i + + ) {
2023-11-08 17:45:35 +00:00
Xd [ i ] = std : : vector < int > ( X [ i ] . begin ( ) , X [ i ] . end ( ) ) ;
2023-07-14 23:59:30 +00:00
}
auto yp = model . predict ( Xd ) ;
return yp ;
}
2024-02-22 10:45:40 +00:00
// Class-probability estimates for X, delegated to the underlying network.
// Throws std::logic_error if called before the classifier has been fitted.
torch::Tensor Classifier::predict_proba(torch::Tensor& X)
{
    if (fitted) {
        return model.predict_proba(X);
    }
    throw std::logic_error(CLASSIFIER_NOT_FITTED);
}
std : : vector < std : : vector < double > > Classifier : : predict_proba ( std : : vector < std : : vector < int > > & X )
{
if ( ! fitted ) {
throw std : : logic_error ( CLASSIFIER_NOT_FITTED ) ;
}
auto m_ = X [ 0 ] . size ( ) ;
auto n_ = X . size ( ) ;
std : : vector < std : : vector < int > > Xd ( n_ , std : : vector < int > ( m_ , 0 ) ) ;
2024-02-23 19:36:11 +00:00
// Convert to nxm vector
2024-02-22 10:45:40 +00:00
for ( auto i = 0 ; i < n_ ; i + + ) {
Xd [ i ] = std : : vector < int > ( X [ i ] . begin ( ) , X [ i ] . end ( ) ) ;
}
auto yp = model . predict_proba ( Xd ) ;
return yp ;
}
2023-11-08 17:45:35 +00:00
float Classifier : : score ( torch : : Tensor & X , torch : : Tensor & y )
2023-07-13 01:15:42 +00:00
{
2023-11-08 17:45:35 +00:00
torch : : Tensor y_pred = predict ( X ) ;
2023-07-13 01:15:42 +00:00
return ( y_pred = = y ) . sum ( ) . item < float > ( ) / y . size ( 0 ) ;
}
2023-11-08 17:45:35 +00:00
float Classifier : : score ( std : : vector < std : : vector < int > > & X , std : : vector < int > & y )
2023-07-14 23:05:36 +00:00
{
if ( ! fitted ) {
2024-02-22 10:45:40 +00:00
throw std : : logic_error ( CLASSIFIER_NOT_FITTED ) ;
2023-07-14 23:05:36 +00:00
}
2023-07-30 17:00:02 +00:00
return model . score ( X , y ) ;
2023-07-14 23:05:36 +00:00
}
2023-11-08 17:45:35 +00:00
std : : vector < std : : string > Classifier : : show ( ) const
2023-07-13 22:10:55 +00:00
{
return model . show ( ) ;
}
2023-07-22 21:07:56 +00:00
void Classifier : : addNodes ( )
2023-07-13 14:59:06 +00:00
{
2023-07-13 22:10:55 +00:00
// Add all nodes to the network
2023-08-03 18:22:33 +00:00
for ( const auto & feature : features ) {
2023-08-05 12:40:42 +00:00
model . addNode ( feature ) ;
2023-07-13 22:10:55 +00:00
}
2023-08-05 12:40:42 +00:00
model . addNode ( className ) ;
2023-07-13 14:59:06 +00:00
}
2023-08-07 23:53:41 +00:00
int Classifier : : getNumberOfNodes ( ) const
2023-07-19 13:05:44 +00:00
{
// Features does not include class
2023-10-23 20:46:10 +00:00
return fitted ? model . getFeatures ( ) . size ( ) : 0 ;
2023-07-19 13:05:44 +00:00
}
2023-08-07 23:53:41 +00:00
int Classifier : : getNumberOfEdges ( ) const
2023-07-19 13:05:44 +00:00
{
2023-08-07 23:53:41 +00:00
return fitted ? model . getNumEdges ( ) : 0 ;
2023-07-19 13:05:44 +00:00
}
2023-08-07 23:53:41 +00:00
int Classifier : : getNumberOfStates ( ) const
2023-07-26 17:01:39 +00:00
{
return fitted ? model . getStates ( ) : 0 ;
}
2024-02-22 10:45:40 +00:00
int Classifier : : getClassNumStates ( ) const
{
return fitted ? model . getClassNumStates ( ) : 0 ;
}
2023-11-08 17:45:35 +00:00
std : : vector < std : : string > Classifier : : topological_order ( )
2023-08-01 22:56:52 +00:00
{
return model . topological_sort ( ) ;
}
2024-04-07 23:25:14 +00:00
std : : string Classifier : : dump_cpt ( ) const
2023-08-03 18:22:33 +00:00
{
2024-04-07 23:25:14 +00:00
return model . dump_cpt ( ) ;
2023-08-03 18:22:33 +00:00
}
2023-11-18 10:56:10 +00:00
void Classifier : : setHyperparameters ( const nlohmann : : json & hyperparameters )
2023-08-24 10:09:35 +00:00
{
2024-04-07 22:55:30 +00:00
if ( ! hyperparameters . empty ( ) ) {
throw std : : invalid_argument ( " Invalid hyperparameters " + hyperparameters . dump ( ) ) ;
}
2023-08-24 10:09:35 +00:00
}
2023-07-13 01:15:42 +00:00
}