Add entropy, conditionalEntropy, mutualInformation and conditionalEdgeWeight methods
parent 3750662f2c
commit c7e2042c6e
.vscode/settings.json (vendored, 4 changed lines)
@@ -88,7 +88,9 @@
     "iterator": "cpp",
     "memory_resource": "cpp",
     "format": "cpp",
-    "valarray": "cpp"
+    "valarray": "cpp",
+    "regex": "cpp",
+    "span": "cpp"
     },
     "cmake.configureOnOpen": false,
     "C_Cpp.default.configurationProvider": "ms-vscode.cmake-tools"
data/ecoli.arff (new executable file, 428 lines)
@@ -0,0 +1,428 @@
%
% 1. Title: Protein Localization Sites
%
%
% 2. Creator and Maintainer:
%      Kenta Nakai
%      Institute of Molecular and Cellular Biology
%      Osaka University
%      1-3 Yamada-oka, Suita 565 Japan
%      nakai@imcb.osaka-u.ac.jp
%      http://www.imcb.osaka-u.ac.jp/nakai/psort.html
% Donor: Paul Horton (paulh@cs.berkeley.edu)
% Date: September, 1996
% See also: yeast database
%
% 3. Past Usage.
% Reference: "A Probabilistic Classification System for Predicting the Cellular
%            Localization Sites of Proteins", Paul Horton & Kenta Nakai,
%            Intelligent Systems in Molecular Biology, 109-115.
%            St. Louis, USA 1996.
% Results: 81% for E.coli with an ad hoc structured
%          probability model. Also similar accuracy for Binary Decision Tree and
%          Bayesian Classifier methods applied by the same authors in
%          unpublished results.
%
% Predicted Attribute: Localization site of protein (non-numeric).
%
%
% 4. The references below describe a predecessor to this dataset and its
% development. They also give results (not cross-validated) for classification
% by a rule-based expert system with that version of the dataset.
%
% Reference: "Expert System for Predicting Protein Localization Sites in
%            Gram-Negative Bacteria", Kenta Nakai & Minoru Kanehisa,
%            PROTEINS: Structure, Function, and Genetics 11:95-110, 1991.
%
% Reference: "A Knowledge Base for Predicting Protein Localization Sites in
%            Eukaryotic Cells", Kenta Nakai & Minoru Kanehisa,
%            Genomics 14:897-911, 1992.
%
%
% 5. Number of Instances: 336 for the E.coli dataset.
%
%
% 6. Number of Attributes.
%          for E.coli dataset: 8 (7 predictive, 1 name)
%
% 7. Attribute Information.
%
% 1. Sequence Name: Accession number for the SWISS-PROT database
% 2. mcg: McGeoch's method for signal sequence recognition.
% 3. gvh: von Heijne's method for signal sequence recognition.
% 4. lip: von Heijne's Signal Peptidase II consensus sequence score.
%         Binary attribute.
% 5. chg: Presence of charge on N-terminus of predicted lipoproteins.
%         Binary attribute.
% 6. aac: score of discriminant analysis of the amino acid content of
%         outer membrane and periplasmic proteins.
% 7. alm1: score of the ALOM membrane spanning region prediction program.
% 8. alm2: score of ALOM program after excluding putative cleavable signal
%          regions from the sequence.
%
% NOTE - the sequence name has been removed
%
% 8. Missing Attribute Values: None.
%
%
% 9. Class Distribution. The class is the localization site. Please see Nakai &
%    Kanehisa referenced above for more details.
%
%  cp  (cytoplasm)                                     143
%  im  (inner membrane without signal sequence)         77
%  pp  (periplasm)                                      52
%  imU (inner membrane, uncleavable signal sequence)    35
%  om  (outer membrane)                                 20
%  omL (outer membrane lipoprotein)                      5
%  imL (inner membrane lipoprotein)                      2
%  imS (inner membrane, cleavable signal sequence)       2

@relation ecoli

@attribute mcg numeric
@attribute gvh numeric
@attribute lip numeric
@attribute chg numeric
@attribute aac numeric
@attribute alm1 numeric
@attribute alm2 numeric
@attribute class {cp,im,pp,imU,om,omL,imL,imS}

@data

0.49,0.29,0.48,0.5,0.56,0.24,0.35,cp
0.07,0.4,0.48,0.5,0.54,0.35,0.44,cp
0.56,0.4,0.48,0.5,0.49,0.37,0.46,cp
[... 333 further rows: the complete 336-instance listing (143 cp, 77 im, 2 imS, 2 imL, 35 imU, 20 om, 5 omL, 52 pp, matching the class distribution in the header) continues in the committed file ...]
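A quick hand check that ties this dataset to the new entropy code further down: from the class counts listed in the header above (143, 77, 52, 35, 20, 5, 2 and 2 out of 336), the entropy of the class variable is

$$H(\mathrm{class}) = -\sum_i p_i \log_2 p_i = -\Bigl(\tfrac{143}{336}\log_2\tfrac{143}{336} + \tfrac{77}{336}\log_2\tfrac{77}{336} + \dots + \tfrac{2}{336}\log_2\tfrac{2}{336}\Bigr) \approx 2.19\ \text{bits},$$

so the base-2 entropy() helper in sample/test.cc should report roughly 2.19 for the class column, and the natural-log Network::entropy() roughly 1.52 nats.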
@@ -138,6 +138,7 @@ pair<string, string> get_options(int argc, char** argv)
 {
     map<string, bool> datasets = {
         {"diabetes", true},
+        {"ecoli", true},
         {"glass", true},
         {"iris", true},
         {"kdd_JapaneseVowels", false},
@@ -229,5 +230,6 @@ int main(int argc, char** argv)
     cout << "BayesNet version: " << network.version() << endl;
     unsigned int nthreads = std::thread::hardware_concurrency();
     cout << "Computer has " << nthreads << " cores." << endl;
+    cout << "conditionalEdgeWeight " << endl << network.conditionalEdgeWeight() << endl;
     return 0;
 }
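A minimal sketch of how the tensor printed above can be read back, assuming the network was fitted on the ecoli data loaded by this program (the index-to-feature mapping is an assumption for illustration, not something stated in the commit):

    // conditionalEdgeWeight() returns a square, symmetric matrix whose rows and columns
    // follow the feature order passed to fit(), with the class variable as the last index;
    // for ecoli (7 features + class) that is an 8x8 tensor with zeros on the diagonal.
    auto weights = network.conditionalEdgeWeight();
    double w01 = weights[0][1].item<double>();   // class-conditional weight between features 0 and 1
    cout << "weight(0, 1) = " << w01 << endl;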
sample/test.cc (166 changed lines)
@@ -25,36 +25,154 @@
#include <vector>
#include <string>
using namespace std;
// Shannon entropy of a discrete (integer-coded) feature, in bits.
double entropy(torch::Tensor feature)
{
    torch::Tensor counts = feature.bincount();
    int totalWeight = counts.sum().item<int>();
    torch::Tensor probs = counts.to(torch::kFloat) / totalWeight;
    torch::Tensor logProbs = torch::log2(probs);
    torch::Tensor entropy = -probs * logProbs;
    return entropy.sum().item<double>();
}
// H(Y|X) = sum_{x in X} p(x) H(Y|X=x)
double conditionalEntropy(torch::Tensor firstFeature, torch::Tensor secondFeature)
{
    int numSamples = firstFeature.sizes()[0];
    torch::Tensor featureCounts = secondFeature.bincount();
    unordered_map<int, unordered_map<int, double>> jointCounts;
    double totalWeight = 0;
    for (auto i = 0; i < numSamples; i++) {
        jointCounts[secondFeature[i].item<int>()][firstFeature[i].item<int>()] += 1;
        totalWeight += 1;
    }
    if (totalWeight == 0)
        throw invalid_argument("Total weight should not be zero");
    double entropy = 0;
    for (int value = 0; value < featureCounts.sizes()[0]; ++value) {
        double p_f = featureCounts[value].item<double>() / totalWeight;
        double entropy_f = 0;
        for (auto& [label, jointCount] : jointCounts[value]) {
            // jointCounts only stores observed pairs, so p_l_f is always positive here.
            double p_l_f = jointCount / featureCounts[value].item<double>();
            if (p_l_f > 0) {
                entropy_f -= p_l_f * log2(p_l_f);
            } else {
                entropy_f = 0;
            }
        }
        entropy += p_f * entropy_f;
    }
    return entropy;
}

// I(X;Y) = H(X) - H(X|Y)
double mutualInformation(torch::Tensor firstFeature, torch::Tensor secondFeature)
{
    return entropy(firstFeature) - conditionalEntropy(firstFeature, secondFeature);
}
// Entropy of an explicit probability vector, in nats (torch::special::entr computes -p*ln(p)).
double entropy2(torch::Tensor feature)
{
    return torch::special::entr(feature).sum().item<double>();
}
int main()
{
    //int i = 3, j = 1, k = 2; // Indices for the cell you want to update
    // Print original tensor
    // torch::Tensor t = torch::tensor({ {1, 2, 3}, {4, 5, 6} }); // 3D tensor for this example
    auto variables = vector<string>{ "A", "B" };
    auto cardinalities = vector<int>{ 5, 4 };
    torch::Tensor values = torch::rand({ 5, 4 });
    auto candidate = "B";
    vector<string> newVariables;
    vector<int> newCardinalities;
    for (int i = 0; i < variables.size(); i++) {
        if (variables[i] != candidate) {
            newVariables.push_back(variables[i]);
            newCardinalities.push_back(cardinalities[i]);
        }
    }
    torch::Tensor newValues = values.sum(1);
    cout << "original values" << endl;
    cout << values << endl;
    cout << "newValues" << endl;
    cout << newValues << endl;
    cout << "newVariables" << endl;
    for (auto& variable : newVariables) {
        cout << variable << endl;
    }
    cout << "newCardinalities" << endl;
    for (auto& cardinality : newCardinalities) {
        cout << cardinality << endl;
    // auto variables = vector<string>{ "A", "B" };
    // auto cardinalities = vector<int>{ 5, 4 };
    // torch::Tensor values = torch::rand({ 5, 4 });
    // auto candidate = "B";
    // vector<string> newVariables;
    // vector<int> newCardinalities;
    // for (int i = 0; i < variables.size(); i++) {
    //     if (variables[i] != candidate) {
    //         newVariables.push_back(variables[i]);
    //         newCardinalities.push_back(cardinalities[i]);
    //     }
    // }
    // torch::Tensor newValues = values.sum(1);
    // cout << "original values" << endl;
    // cout << values << endl;
    // cout << "newValues" << endl;
    // cout << newValues << endl;
    // cout << "newVariables" << endl;
    // for (auto& variable : newVariables) {
    //     cout << variable << endl;
    // }
    // cout << "newCardinalities" << endl;
    // for (auto& cardinality : newCardinalities) {
    //     cout << cardinality << endl;
    // }
    // auto row2 = values.index({ torch::tensor(1) }); //
    // cout << "row2" << endl;
    // cout << row2 << endl;
    // auto col2 = values.index({ "...", 1 });
    // cout << "col2" << endl;
    // cout << col2 << endl;
    // auto col_last = values.index({ "...", -1 });
    // cout << "col_last" << endl;
    // cout << col_last << endl;
    // values.index_put_({ "...", -1 }, torch::tensor({ 1,2,3,4,5 }));
    // cout << "col_last" << endl;
    // cout << col_last << endl;
    // auto slice2 = values.index({ torch::indexing::Slice(1, torch::indexing::None) });
    // cout << "slice2" << endl;
    // cout << slice2 << endl;
    // auto mask = values.index({ "...", -1 }) % 2 == 0;
    // auto filter = values.index({ mask, 2 }); // Filter values
    // cout << "filter" << endl;
    // cout << filter << endl;
    // torch::Tensor dataset = torch::tensor({ {1,0,0,1},{1,1,1,2},{0,0,0,1},{1,0,2,0},{0,0,3,0} });
    // cout << "dataset" << endl;
    // cout << dataset << endl;
    // cout << "entropy(dataset.indices('...', 2))" << endl;
    // cout << dataset.index({ "...", 2 }) << endl;
    // cout << "*********************************" << endl;
    // for (int i = 0; i < 4; i++) {
    //     cout << "datset(" << i << ")" << endl;
    //     cout << dataset.index({ "...", i }) << endl;
    //     cout << "entropy(" << i << ")" << endl;
    //     cout << entropy(dataset.index({ "...", i })) << endl;
    // }
    // cout << "......................................" << endl;
    // //cout << entropy2(dataset.index({ "...", 2 }));
    // cout << "conditional entropy 0 2" << endl;
    // cout << conditionalEntropy(dataset.index({ "...", 0 }), dataset.index({ "...", 2 })) << endl;
    // cout << "mutualInformation(dataset.index({ '...', 0 }), dataset.index({ '...', 2 }))" << endl;
    // cout << mutualInformation(dataset.index({ "...", 0 }), dataset.index({ "...", 2 })) << endl;
    // auto test = torch::tensor({ .1, .2, .3 }, torch::kFloat);
    // auto result = torch::zeros({ 3, 3 }, torch::kFloat);
    // result.index_put_({ indices }, test);
    // cout << "indices" << endl;
    // cout << indices << endl;
    // cout << "result" << endl;
    // cout << result << endl;
    // cout << "Test" << endl;
    // cout << torch::triu(test.reshape(3, 3), torch::kFloat)) << endl;

    // Create a 3x3 tensor with zeros
    torch::Tensor tensor_3x3 = torch::zeros({ 3, 3 }, torch::kFloat);

    // Create a 1D tensor with the three elements you want to set in the upper corner
    torch::Tensor tensor_1d = torch::tensor({ 10, 11, 12 }, torch::kFloat);

    // Set the upper corner of the 3x3 tensor
    auto indices = torch::triu_indices(3, 3, 1);
    for (auto i = 0; i < tensor_1d.sizes()[0]; ++i) {
        auto x = indices[0][i];
        auto y = indices[1][i];
        tensor_3x3[x][y] = tensor_1d[i];
        tensor_3x3[y][x] = tensor_1d[i];
    }
    // Print the resulting 3x3 tensor
    std::cout << tensor_3x3 << std::endl;

    // std::cout << t << std::endl;
    // std::cout << "sum(0)" << std::endl;
    // std::cout << t.sum(0) << std::endl;
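The three helpers at the top of sample/test.cc above can be sanity-checked on a tiny hand-computable example; the expected values follow directly from the definitions (log base 2, so bits). A sketch only, not part of the commit; it could be dropped into main():

    torch::Tensor x = torch::tensor({ 0, 0, 1, 1 });   // two equally likely values
    torch::Tensor y = torch::tensor({ 0, 1, 0, 1 });   // unrelated to x in this sample
    cout << entropy(x) << endl;                  // 1.0 : H(X) of a fair binary variable
    cout << conditionalEntropy(x, x) << endl;    // 0.0 : X is fully determined given itself
    cout << mutualInformation(x, x) << endl;     // 1.0 : I(X;X) = H(X)
    cout << mutualInformation(x, y) << endl;     // 0.0 : x tells us nothing about y here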
src/Network.cc (104 changed lines)
@@ -98,11 +98,14 @@ namespace bayesnet {
         this->className = className;
         dataset.clear();

-        // Build dataset
+        // Build dataset & tensor of samples
+        samples = torch::zeros({ static_cast<int64_t>(input_data[0].size()), static_cast<int64_t>(input_data.size() + 1) }, torch::kInt64);
         for (int i = 0; i < featureNames.size(); ++i) {
             dataset[featureNames[i]] = input_data[i];
+            samples.index_put_({ "...", i }, torch::tensor(input_data[i], torch::kInt64));
         }
         dataset[className] = labels;
+        samples.index_put_({ "...", -1 }, torch::tensor(labels, torch::kInt64));
         classNumStates = *max_element(labels.begin(), labels.end()) + 1;
         int maxThreadsRunning = static_cast<int>(std::thread::hardware_concurrency() * maxThreads);
         if (maxThreadsRunning < 1) {
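A hedged illustration of the layout this gives samples (feature names and values invented for the example): one row per instance, one column per feature, and the class labels in the last column, which is why the methods further down select the class with index -1.

    std::vector<std::vector<int>> X = { {0, 1, 0}, {1, 1, 0} };   // 2 features, 3 instances each
    std::vector<int> y = { 0, 1, 1 };
    network.fit(X, y, { "f1", "f2" }, "class");
    // samples is now a 3x3 kInt64 tensor: column 0 = f1, column 1 = f2, column 2 (index -1) = y.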
@@ -150,14 +153,14 @@
         }
     }

-    vector<int> Network::predict(const vector<vector<int>>& samples)
+    vector<int> Network::predict(const vector<vector<int>>& tsamples)
     {
         vector<int> predictions;
         vector<int> sample;
-        for (int row = 0; row < samples[0].size(); ++row) {
+        for (int row = 0; row < tsamples[0].size(); ++row) {
             sample.clear();
-            for (int col = 0; col < samples.size(); ++col) {
-                sample.push_back(samples[col][row]);
+            for (int col = 0; col < tsamples.size(); ++col) {
+                sample.push_back(tsamples[col][row]);
             }
             vector<double> classProbabilities = predict_sample(sample);
             // Find the class with the maximum posterior probability
@@ -167,22 +170,22 @@
         }
         return predictions;
     }
-    vector<vector<double>> Network::predict_proba(const vector<vector<int>>& samples)
+    vector<vector<double>> Network::predict_proba(const vector<vector<int>>& tsamples)
     {
         vector<vector<double>> predictions;
         vector<int> sample;
-        for (int row = 0; row < samples[0].size(); ++row) {
+        for (int row = 0; row < tsamples[0].size(); ++row) {
             sample.clear();
-            for (int col = 0; col < samples.size(); ++col) {
-                sample.push_back(samples[col][row]);
+            for (int col = 0; col < tsamples.size(); ++col) {
+                sample.push_back(tsamples[col][row]);
             }
             predictions.push_back(predict_sample(sample));
         }
         return predictions;
     }
-    double Network::score(const vector<vector<int>>& samples, const vector<int>& labels)
+    double Network::score(const vector<vector<int>>& tsamples, const vector<int>& labels)
     {
-        vector<int> y_pred = predict(samples);
+        vector<int> y_pred = predict(tsamples);
         int correct = 0;
         for (int i = 0; i < y_pred.size(); ++i) {
             if (y_pred[i] == labels[i]) {
@@ -238,4 +241,83 @@
        }
        return result;
    }
    double Network::mutual_info(torch::Tensor& first, torch::Tensor& second)
    {
        // Stub; the actual computation is mutualInformation() below.
        return 1;
    }
    torch::Tensor Network::conditionalEdgeWeight()
    {
        auto result = vector<double>();
        auto source = vector<string>(features);
        source.push_back(className);
        auto combinations = nodes[className]->combinations(source);
        auto margin = nodes[className]->getCPT();
        for (auto [first, second] : combinations) {
            int64_t index_first = find(features.begin(), features.end(), first) - features.begin();
            int64_t index_second = find(features.begin(), features.end(), second) - features.begin();
            double accumulated = 0;
            for (int value = 0; value < classNumStates; ++value) {
                auto mask = samples.index({ "...", -1 }) == value;
                auto first_dataset = samples.index({ mask, index_first });
                auto second_dataset = samples.index({ mask, index_second });
                auto mi = mutualInformation(first_dataset, second_dataset);
                auto pb = margin[value].item<float>();
                accumulated += pb * mi;
            }
            result.push_back(accumulated);
        }
        long n_vars = source.size();
        auto matrix = torch::zeros({ n_vars, n_vars });
        auto indices = torch::triu_indices(n_vars, n_vars, 1);
        // Spread the pairwise weights into a symmetric matrix; the diagonal stays zero.
        for (auto i = 0; i < result.size(); ++i) {
            auto x = indices[0][i];
            auto y = indices[1][i];
            matrix[x][y] = result[i];
            matrix[y][x] = result[i];
        }
        return matrix;
    }
    double Network::entropy(torch::Tensor& feature)
    {
        torch::Tensor counts = feature.bincount();
        int totalWeight = counts.sum().item<int>();
        torch::Tensor probs = counts.to(torch::kFloat) / totalWeight;
        torch::Tensor logProbs = torch::log(probs);
        torch::Tensor entropy = -probs * logProbs;
        return entropy.nansum().item<double>();
    }
    // H(Y|X) = sum_{x in X} p(x) H(Y|X=x)
    double Network::conditionalEntropy(torch::Tensor& firstFeature, torch::Tensor& secondFeature)
    {
        int numSamples = firstFeature.sizes()[0];
        torch::Tensor featureCounts = secondFeature.bincount();
        unordered_map<int, unordered_map<int, double>> jointCounts;
        double totalWeight = 0;
        for (auto i = 0; i < numSamples; i++) {
            jointCounts[secondFeature[i].item<int>()][firstFeature[i].item<int>()] += 1;
            totalWeight += 1;
        }
        if (totalWeight == 0)
            throw invalid_argument("Total weight should not be zero");
        double entropyValue = 0;
        for (int value = 0; value < featureCounts.sizes()[0]; ++value) {
            double p_f = featureCounts[value].item<double>() / totalWeight;
            double entropy_f = 0;
            for (auto& [label, jointCount] : jointCounts[value]) {
                double p_l_f = jointCount / featureCounts[value].item<double>();
                if (p_l_f > 0) {
                    entropy_f -= p_l_f * log(p_l_f);
                } else {
                    entropy_f = 0;
                }
            }
            entropyValue += p_f * entropy_f;
        }
        return entropyValue;
    }
    // I(X;Y) = H(X) - H(X|Y)
    double Network::mutualInformation(torch::Tensor& firstFeature, torch::Tensor& secondFeature)
    {
        return entropy(firstFeature) - conditionalEntropy(firstFeature, secondFeature);
    }
}
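Written out, the weight that conditionalEdgeWeight() accumulates for each pair of variables is the class-conditional mutual information, with the class prior taken from the class node's CPT:

$$I(X_i; X_j \mid C) = \sum_{c} P(C = c)\, I(X_i; X_j \mid C = c), \qquad I(X; Y) = H(X) - H(X \mid Y),$$

computed with the entropy() and conditionalEntropy() defined just above (natural logarithm, so the resulting weights are in nats).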
@@ -19,7 +19,12 @@ namespace bayesnet {
        vector<double> predict_sample(const vector<int>&);
        vector<double> exactInference(map<string, int>&);
        double computeFactor(map<string, int>&);
+       double mutual_info(torch::Tensor&, torch::Tensor&);
+       double entropy(torch::Tensor&);
+       double conditionalEntropy(torch::Tensor&, torch::Tensor&);
+       double mutualInformation(torch::Tensor&, torch::Tensor&);
    public:
+       torch::Tensor samples;
        Network();
        Network(float, int);
        Network(float);
@@ -35,6 +40,8 @@ namespace bayesnet {
        string getClassName();
        void fit(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&);
        vector<int> predict(const vector<vector<int>>&);
+       // Computes the conditional edge weight of variable index u and v conditioned on class_node
+       torch::Tensor conditionalEdgeWeight();
        vector<vector<double>> predict_proba(const vector<vector<int>>&);
        double score(const vector<vector<int>>&, const vector<int>&);
        inline string version() { return "0.1.0"; }
src/Node.cc (12 changed lines)
@@ -57,23 +57,23 @@ namespace bayesnet {
     */
    unsigned Node::minFill()
    {
-       set<string> neighbors;
+       unordered_set<string> neighbors;
        for (auto child : children) {
            neighbors.emplace(child->getName());
        }
        for (auto parent : parents) {
            neighbors.emplace(parent->getName());
        }
-       return combinations(neighbors).size();
+       auto source = vector<string>(neighbors.begin(), neighbors.end());
+       return combinations(source).size();
    }
-   vector<string> Node::combinations(const set<string>& neighbors)
+   vector<pair<string, string>> Node::combinations(const vector<string>& source)
    {
-       vector<string> source(neighbors.begin(), neighbors.end());
-       vector<string> result;
+       vector<pair<string, string>> result;
        for (int i = 0; i < source.size(); ++i) {
            string temp = source[i];
            for (int j = i + 1; j < source.size(); ++j) {
-               result.push_back(temp + source[j]);
+               result.push_back({ temp, source[j] });
            }
        }
        return result;
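For illustration only (not part of the commit): the reworked combinations() returns the n(n-1)/2 unordered pairs of its argument, which is what both minFill() above and Network::conditionalEdgeWeight() rely on. A minimal sketch, using the public constructor declared in Node.h below:

    bayesnet::Node node("A", 2);                          // name and state count are arbitrary here
    auto pairs = node.combinations({ "A", "B", "C" });    // -> ("A","B"), ("A","C"), ("B","C")
    for (const auto& [first, second] : pairs) {
        std::cout << first << " - " << second << std::endl;
    }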
@@ -1,6 +1,7 @@
 #ifndef NODE_H
 #define NODE_H
 #include <torch/torch.h>
+#include <unordered_set>
 #include <vector>
 #include <string>
 namespace bayesnet {
@@ -13,8 +14,8 @@ namespace bayesnet {
        int numStates; // number of states of the variable
        torch::Tensor cpTable; // Order of indices is 0-> node variable, 1-> 1st parent, 2-> 2nd parent, ...
        vector<int64_t> dimensions; // dimensions of the cpTable
-       vector<string> combinations(const set<string>&);
    public:
+       vector<pair<string, string>> combinations(const vector<string>&);
        Node(const std::string&, int);
        void addParent(Node*);
        void addChild(Node*);