diff --git a/Untitled.ipynb b/Untitled.ipynb
new file mode 100644
index 0000000..69ab701
--- /dev/null
+++ b/Untitled.ipynb
@@ -0,0 +1,526 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "0e48f7d2-7481-4eca-9c38-56d21c203093",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "DEBUG:weka.core.jvm:Adding bundled jars\n",
+ "DEBUG:weka.core.jvm:Classpath=['/Users/rmontanana/miniconda3/envs/pyweka/lib/python3.10/site-packages/javabridge/jars/rhino-1.7R4.jar', '/Users/rmontanana/miniconda3/envs/pyweka/lib/python3.10/site-packages/javabridge/jars/runnablequeue.jar', '/Users/rmontanana/miniconda3/envs/pyweka/lib/python3.10/site-packages/javabridge/jars/cpython.jar', '/Users/rmontanana/miniconda3/envs/pyweka/lib/python3.10/site-packages/weka/lib/python-weka-wrapper.jar', '/Users/rmontanana/miniconda3/envs/pyweka/lib/python3.10/site-packages/weka/lib/weka.jar']\n",
+ "DEBUG:weka.core.jvm:MaxHeapSize=default\n",
+ "DEBUG:weka.core.jvm:Package support disabled\n",
+ "WARNING: An illegal reflective access operation has occurred\n",
+ "WARNING: Illegal reflective access by weka.core.WekaPackageClassLoaderManager (file:/Users/rmontanana/miniconda3/envs/pyweka/lib/python3.10/site-packages/weka/lib/weka.jar) to method java.lang.ClassLoader.defineClass(java.lang.String,byte[],int,int,java.security.ProtectionDomain)\n",
+ "WARNING: Please consider reporting this to the maintainers of weka.core.WekaPackageClassLoaderManager\n",
+ "WARNING: Use --illegal-access=warn to enable warnings of further illegal reflective access operations\n",
+ "WARNING: All illegal access operations will be denied in a future release\n"
+ ]
+ }
+ ],
+ "source": [
+ "import weka.core.jvm as jvm\n",
+ "jvm.start()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "2ac4e479-3818-4562-a967-bb303d8dd573",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from weka.core.converters import Loader\n",
+ "data_dir = \"/Users/rmontanana/Code/discretizbench/datasets/\"\n",
+ "loader = Loader(classname=\"weka.core.converters.ArffLoader\")\n",
+ "data = loader.load_file(data_dir + \"iris.arff\")\n",
+ "data.class_is_last()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "ceb9f912-db42-4cbc-808f-48b5a9d89d44",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "@relation iris\n",
+ "\n",
+ "@attribute sepallength numeric\n",
+ "@attribute sepalwidth numeric\n",
+ "@attribute petallength numeric\n",
+ "@attribute petalwidth numeric\n",
+ "@attribute class {Iris-setosa,Iris-versicolor,Iris-virginica}\n",
+ "\n",
+ "@data\n",
+ "5.1,3.5,1.4,0.2,Iris-setosa\n",
+ "4.9,3,1.4,0.2,Iris-setosa\n",
+ "4.7,3.2,1.3,0.2,Iris-setosa\n",
+ "4.6,3.1,1.5,0.2,Iris-setosa\n",
+ "5,3.6,1.4,0.2,Iris-setosa\n",
+ "5.4,3.9,1.7,0.4,Iris-setosa\n",
+ "4.6,3.4,1.4,0.3,Iris-setosa\n",
+ "5,3.4,1.5,0.2,Iris-setosa\n",
+ "4.4,2.9,1.4,0.2,Iris-setosa\n",
+ "4.9,3.1,1.5,0.1,Iris-setosa\n",
+ "5.4,3.7,1.5,0.2,Iris-setosa\n",
+ "4.8,3.4,1.6,0.2,Iris-setosa\n",
+ "4.8,3,1.4,0.1,Iris-setosa\n",
+ "4.3,3,1.1,0.1,Iris-setosa\n",
+ "5.8,4,1.2,0.2,Iris-setosa\n",
+ "5.7,4.4,1.5,0.4,Iris-setosa\n",
+ "5.4,3.9,1.3,0.4,Iris-setosa\n",
+ "5.1,3.5,1.4,0.3,Iris-setosa\n",
+ "5.7,3.8,1.7,0.3,Iris-setosa\n",
+ "5.1,3.8,1.5,0.3,Iris-setosa\n",
+ "5.4,3.4,1.7,0.2,Iris-setosa\n",
+ "5.1,3.7,1.5,0.4,Iris-setosa\n",
+ "4.6,3.6,1,0.2,Iris-setosa\n",
+ "5.1,3.3,1.7,0.5,Iris-setosa\n",
+ "4.8,3.4,1.9,0.2,Iris-setosa\n",
+ "5,3,1.6,0.2,Iris-setosa\n",
+ "5,3.4,1.6,0.4,Iris-setosa\n",
+ "5.2,3.5,1.5,0.2,Iris-setosa\n",
+ "5.2,3.4,1.4,0.2,Iris-setosa\n",
+ "4.7,3.2,1.6,0.2,Iris-setosa\n",
+ "4.8,3.1,1.6,0.2,Iris-setosa\n",
+ "5.4,3.4,1.5,0.4,Iris-setosa\n",
+ "5.2,4.1,1.5,0.1,Iris-setosa\n",
+ "5.5,4.2,1.4,0.2,Iris-setosa\n",
+ "4.9,3.1,1.5,0.1,Iris-setosa\n",
+ "5,3.2,1.2,0.2,Iris-setosa\n",
+ "5.5,3.5,1.3,0.2,Iris-setosa\n",
+ "4.9,3.1,1.5,0.1,Iris-setosa\n",
+ "4.4,3,1.3,0.2,Iris-setosa\n",
+ "5.1,3.4,1.5,0.2,Iris-setosa\n",
+ "5,3.5,1.3,0.3,Iris-setosa\n",
+ "4.5,2.3,1.3,0.3,Iris-setosa\n",
+ "4.4,3.2,1.3,0.2,Iris-setosa\n",
+ "5,3.5,1.6,0.6,Iris-setosa\n",
+ "5.1,3.8,1.9,0.4,Iris-setosa\n",
+ "4.8,3,1.4,0.3,Iris-setosa\n",
+ "5.1,3.8,1.6,0.2,Iris-setosa\n",
+ "4.6,3.2,1.4,0.2,Iris-setosa\n",
+ "5.3,3.7,1.5,0.2,Iris-setosa\n",
+ "5,3.3,1.4,0.2,Iris-setosa\n",
+ "7,3.2,4.7,1.4,Iris-versicolor\n",
+ "6.4,3.2,4.5,1.5,Iris-versicolor\n",
+ "6.9,3.1,4.9,1.5,Iris-versicolor\n",
+ "5.5,2.3,4,1.3,Iris-versicolor\n",
+ "6.5,2.8,4.6,1.5,Iris-versicolor\n",
+ "5.7,2.8,4.5,1.3,Iris-versicolor\n",
+ "6.3,3.3,4.7,1.6,Iris-versicolor\n",
+ "4.9,2.4,3.3,1,Iris-versicolor\n",
+ "6.6,2.9,4.6,1.3,Iris-versicolor\n",
+ "5.2,2.7,3.9,1.4,Iris-versicolor\n",
+ "5,2,3.5,1,Iris-versicolor\n",
+ "5.9,3,4.2,1.5,Iris-versicolor\n",
+ "6,2.2,4,1,Iris-versicolor\n",
+ "6.1,2.9,4.7,1.4,Iris-versicolor\n",
+ "5.6,2.9,3.6,1.3,Iris-versicolor\n",
+ "6.7,3.1,4.4,1.4,Iris-versicolor\n",
+ "5.6,3,4.5,1.5,Iris-versicolor\n",
+ "5.8,2.7,4.1,1,Iris-versicolor\n",
+ "6.2,2.2,4.5,1.5,Iris-versicolor\n",
+ "5.6,2.5,3.9,1.1,Iris-versicolor\n",
+ "5.9,3.2,4.8,1.8,Iris-versicolor\n",
+ "6.1,2.8,4,1.3,Iris-versicolor\n",
+ "6.3,2.5,4.9,1.5,Iris-versicolor\n",
+ "6.1,2.8,4.7,1.2,Iris-versicolor\n",
+ "6.4,2.9,4.3,1.3,Iris-versicolor\n",
+ "6.6,3,4.4,1.4,Iris-versicolor\n",
+ "6.8,2.8,4.8,1.4,Iris-versicolor\n",
+ "6.7,3,5,1.7,Iris-versicolor\n",
+ "6,2.9,4.5,1.5,Iris-versicolor\n",
+ "5.7,2.6,3.5,1,Iris-versicolor\n",
+ "5.5,2.4,3.8,1.1,Iris-versicolor\n",
+ "5.5,2.4,3.7,1,Iris-versicolor\n",
+ "5.8,2.7,3.9,1.2,Iris-versicolor\n",
+ "6,2.7,5.1,1.6,Iris-versicolor\n",
+ "5.4,3,4.5,1.5,Iris-versicolor\n",
+ "6,3.4,4.5,1.6,Iris-versicolor\n",
+ "6.7,3.1,4.7,1.5,Iris-versicolor\n",
+ "6.3,2.3,4.4,1.3,Iris-versicolor\n",
+ "5.6,3,4.1,1.3,Iris-versicolor\n",
+ "5.5,2.5,4,1.3,Iris-versicolor\n",
+ "5.5,2.6,4.4,1.2,Iris-versicolor\n",
+ "6.1,3,4.6,1.4,Iris-versicolor\n",
+ "5.8,2.6,4,1.2,Iris-versicolor\n",
+ "5,2.3,3.3,1,Iris-versicolor\n",
+ "5.6,2.7,4.2,1.3,Iris-versicolor\n",
+ "5.7,3,4.2,1.2,Iris-versicolor\n",
+ "5.7,2.9,4.2,1.3,Iris-versicolor\n",
+ "6.2,2.9,4.3,1.3,Iris-versicolor\n",
+ "5.1,2.5,3,1.1,Iris-versicolor\n",
+ "5.7,2.8,4.1,1.3,Iris-versicolor\n",
+ "6.3,3.3,6,2.5,Iris-virginica\n",
+ "5.8,2.7,5.1,1.9,Iris-virginica\n",
+ "7.1,3,5.9,2.1,Iris-virginica\n",
+ "6.3,2.9,5.6,1.8,Iris-virginica\n",
+ "6.5,3,5.8,2.2,Iris-virginica\n",
+ "7.6,3,6.6,2.1,Iris-virginica\n",
+ "4.9,2.5,4.5,1.7,Iris-virginica\n",
+ "7.3,2.9,6.3,1.8,Iris-virginica\n",
+ "6.7,2.5,5.8,1.8,Iris-virginica\n",
+ "7.2,3.6,6.1,2.5,Iris-virginica\n",
+ "6.5,3.2,5.1,2,Iris-virginica\n",
+ "6.4,2.7,5.3,1.9,Iris-virginica\n",
+ "6.8,3,5.5,2.1,Iris-virginica\n",
+ "5.7,2.5,5,2,Iris-virginica\n",
+ "5.8,2.8,5.1,2.4,Iris-virginica\n",
+ "6.4,3.2,5.3,2.3,Iris-virginica\n",
+ "6.5,3,5.5,1.8,Iris-virginica\n",
+ "7.7,3.8,6.7,2.2,Iris-virginica\n",
+ "7.7,2.6,6.9,2.3,Iris-virginica\n",
+ "6,2.2,5,1.5,Iris-virginica\n",
+ "6.9,3.2,5.7,2.3,Iris-virginica\n",
+ "5.6,2.8,4.9,2,Iris-virginica\n",
+ "7.7,2.8,6.7,2,Iris-virginica\n",
+ "6.3,2.7,4.9,1.8,Iris-virginica\n",
+ "6.7,3.3,5.7,2.1,Iris-virginica\n",
+ "7.2,3.2,6,1.8,Iris-virginica\n",
+ "6.2,2.8,4.8,1.8,Iris-virginica\n",
+ "6.1,3,4.9,1.8,Iris-virginica\n",
+ "6.4,2.8,5.6,2.1,Iris-virginica\n",
+ "7.2,3,5.8,1.6,Iris-virginica\n",
+ "7.4,2.8,6.1,1.9,Iris-virginica\n",
+ "7.9,3.8,6.4,2,Iris-virginica\n",
+ "6.4,2.8,5.6,2.2,Iris-virginica\n",
+ "6.3,2.8,5.1,1.5,Iris-virginica\n",
+ "6.1,2.6,5.6,1.4,Iris-virginica\n",
+ "7.7,3,6.1,2.3,Iris-virginica\n",
+ "6.3,3.4,5.6,2.4,Iris-virginica\n",
+ "6.4,3.1,5.5,1.8,Iris-virginica\n",
+ "6,3,4.8,1.8,Iris-virginica\n",
+ "6.9,3.1,5.4,2.1,Iris-virginica\n",
+ "6.7,3.1,5.6,2.4,Iris-virginica\n",
+ "6.9,3.1,5.1,2.3,Iris-virginica\n",
+ "5.8,2.7,5.1,1.9,Iris-virginica\n",
+ "6.8,3.2,5.9,2.3,Iris-virginica\n",
+ "6.7,3.3,5.7,2.5,Iris-virginica\n",
+ "6.7,3,5.2,2.3,Iris-virginica\n",
+ "6.3,2.5,5,1.9,Iris-virginica\n",
+ "6.5,3,5.2,2,Iris-virginica\n",
+ "6.2,3.4,5.4,2.3,Iris-virginica\n",
+ "5.9,3,5.1,1.8,Iris-virginica\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(data)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "ded59d25-c34c-4fb8-a35f-1162f1218414",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from weka.classifiers import Classifier\n",
+ "cls = Classifier(classname=\"weka.classifiers.trees.J48\", options=[\"-C\", \"0.3\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "4c82f2ae-4071-4571-9a19-433b98463143",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "['-C', '0.3', '-M', '2']\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(cls.options)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "4c5c7893-ebbe-407d-872c-fd0bf41f8dc8",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "weka.classifiers.trees.J48 -C 0.3 -M 2\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(cls.to_commandline())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "7b73c18d-e0b0-469d-8a60-03bae8e01128",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "1: label index=0.0, class distribution=[0.99688403 0.00188598 0.00122999]\n",
+ "2: label index=0.0, class distribution=[0.99487322 0.00310305 0.00202373]\n",
+ "3: label index=0.0, class distribution=[0.99487322 0.00310305 0.00202373]\n",
+ "4: label index=0.0, class distribution=[0.99487322 0.00310305 0.00202373]\n",
+ "5: label index=0.0, class distribution=[0.99688403 0.00188598 0.00122999]\n",
+ "6: label index=0.0, class distribution=[0.99688403 0.00188598 0.00122999]\n",
+ "7: label index=0.0, class distribution=[0.99688403 0.00188598 0.00122999]\n",
+ "8: label index=0.0, class distribution=[0.99688403 0.00188598 0.00122999]\n",
+ "9: label index=0.0, class distribution=[0.96326708 0.02223308 0.01449983]\n",
+ "10: label index=0.0, class distribution=[0.99487322 0.00310305 0.00202373]\n",
+ "11: label index=0.0, class distribution=[0.99688403 0.00188598 0.00122999]\n",
+ "12: label index=0.0, class distribution=[0.99688403 0.00188598 0.00122999]\n",
+ "13: label index=0.0, class distribution=[0.99487322 0.00310305 0.00202373]\n",
+ "14: label index=0.0, class distribution=[0.99487322 0.00310305 0.00202373]\n",
+ "15: label index=0.0, class distribution=[0.9382677 0.03162683 0.03010547]\n",
+ "16: label index=0.0, class distribution=[0.9382677 0.03162683 0.03010547]\n",
+ "17: label index=0.0, class distribution=[0.99688403 0.00188598 0.00122999]\n",
+ "18: label index=0.0, class distribution=[0.99688403 0.00188598 0.00122999]\n",
+ "19: label index=0.0, class distribution=[0.9382677 0.03162683 0.03010547]\n",
+ "20: label index=0.0, class distribution=[0.99688403 0.00188598 0.00122999]\n",
+ "21: label index=0.0, class distribution=[0.99688403 0.00188598 0.00122999]\n",
+ "22: label index=0.0, class distribution=[0.99688403 0.00188598 0.00122999]\n",
+ "23: label index=0.0, class distribution=[0.99688403 0.00188598 0.00122999]\n",
+ "24: label index=0.0, class distribution=[0.99487322 0.00310305 0.00202373]\n",
+ "25: label index=0.0, class distribution=[0.99688403 0.00188598 0.00122999]\n",
+ "26: label index=0.0, class distribution=[0.99487322 0.00310305 0.00202373]\n",
+ "27: label index=0.0, class distribution=[0.99688403 0.00188598 0.00122999]\n",
+ "28: label index=0.0, class distribution=[0.99688403 0.00188598 0.00122999]\n",
+ "29: label index=0.0, class distribution=[0.99688403 0.00188598 0.00122999]\n",
+ "30: label index=0.0, class distribution=[0.99487322 0.00310305 0.00202373]\n",
+ "31: label index=0.0, class distribution=[0.99487322 0.00310305 0.00202373]\n",
+ "32: label index=0.0, class distribution=[0.99688403 0.00188598 0.00122999]\n",
+ "33: label index=0.0, class distribution=[0.99688403 0.00188598 0.00122999]\n",
+ "34: label index=0.0, class distribution=[0.99688403 0.00188598 0.00122999]\n",
+ "35: label index=0.0, class distribution=[0.99487322 0.00310305 0.00202373]\n",
+ "36: label index=0.0, class distribution=[0.99487322 0.00310305 0.00202373]\n",
+ "37: label index=0.0, class distribution=[0.99688403 0.00188598 0.00122999]\n",
+ "38: label index=0.0, class distribution=[0.99487322 0.00310305 0.00202373]\n",
+ "39: label index=0.0, class distribution=[0.99487322 0.00310305 0.00202373]\n",
+ "40: label index=0.0, class distribution=[0.99688403 0.00188598 0.00122999]\n",
+ "41: label index=0.0, class distribution=[0.99688403 0.00188598 0.00122999]\n",
+ "42: label index=0.0, class distribution=[0.96326708 0.02223308 0.01449983]\n",
+ "43: label index=0.0, class distribution=[0.99487322 0.00310305 0.00202373]\n",
+ "44: label index=0.0, class distribution=[0.99688403 0.00188598 0.00122999]\n",
+ "45: label index=0.0, class distribution=[0.99688403 0.00188598 0.00122999]\n",
+ "46: label index=0.0, class distribution=[0.99487322 0.00310305 0.00202373]\n",
+ "47: label index=0.0, class distribution=[0.99688403 0.00188598 0.00122999]\n",
+ "48: label index=0.0, class distribution=[0.99487322 0.00310305 0.00202373]\n",
+ "49: label index=0.0, class distribution=[0.99688403 0.00188598 0.00122999]\n",
+ "50: label index=0.0, class distribution=[0.99487322 0.00310305 0.00202373]\n",
+ "51: label index=1.0, class distribution=[0.00545355 0.97466198 0.01988447]\n",
+ "52: label index=1.0, class distribution=[0.00545355 0.97466198 0.01988447]\n",
+ "53: label index=1.0, class distribution=[0.010867 0.52425197 0.46488102]\n",
+ "54: label index=1.0, class distribution=[0.00725727 0.94287877 0.04986396]\n",
+ "55: label index=1.0, class distribution=[0.00228744 0.97269152 0.02502104]\n",
+ "56: label index=1.0, class distribution=[0.00308382 0.98338244 0.01353374]\n",
+ "57: label index=1.0, class distribution=[0.00545355 0.97466198 0.01988447]\n",
+ "58: label index=1.0, class distribution=[0.00725727 0.94287877 0.04986396]\n",
+ "59: label index=1.0, class distribution=[0.00228744 0.97269152 0.02502104]\n",
+ "60: label index=1.0, class distribution=[0.00725727 0.94287877 0.04986396]\n",
+ "61: label index=1.0, class distribution=[0.00725727 0.94287877 0.04986396]\n",
+ "62: label index=1.0, class distribution=[0.00732671 0.98195521 0.01071808]\n",
+ "63: label index=1.0, class distribution=[0.00308382 0.98338244 0.01353374]\n",
+ "64: label index=1.0, class distribution=[0.00308382 0.98338244 0.01353374]\n",
+ "65: label index=1.0, class distribution=[0.00308382 0.98338244 0.01353374]\n",
+ "66: label index=1.0, class distribution=[0.00545355 0.97466198 0.01988447]\n",
+ "67: label index=1.0, class distribution=[0.00732671 0.98195521 0.01071808]\n",
+ "68: label index=1.0, class distribution=[0.00308382 0.98338244 0.01353374]\n",
+ "69: label index=1.0, class distribution=[0.00228744 0.97269152 0.02502104]\n",
+ "70: label index=1.0, class distribution=[0.00308382 0.98338244 0.01353374]\n",
+ "71: label index=2.0, class distribution=[0.00920087 0.06127297 0.92952615]\n",
+ "72: label index=1.0, class distribution=[0.00308382 0.98338244 0.01353374]\n",
+ "73: label index=2.0, class distribution=[0.00409632 0.47019227 0.5257114 ]\n",
+ "74: label index=1.0, class distribution=[0.00308382 0.98338244 0.01353374]\n",
+ "75: label index=1.0, class distribution=[0.00228744 0.97269152 0.02502104]\n",
+ "76: label index=1.0, class distribution=[0.00545355 0.97466198 0.01988447]\n",
+ "77: label index=2.0, class distribution=[0.00409632 0.47019227 0.5257114 ]\n",
+ "78: label index=1.0, class distribution=[0.010867 0.52425197 0.46488102]\n",
+ "79: label index=1.0, class distribution=[0.00308382 0.98338244 0.01353374]\n",
+ "80: label index=1.0, class distribution=[0.00308382 0.98338244 0.01353374]\n",
+ "81: label index=1.0, class distribution=[0.00725727 0.94287877 0.04986396]\n",
+ "82: label index=1.0, class distribution=[0.00725727 0.94287877 0.04986396]\n",
+ "83: label index=1.0, class distribution=[0.00308382 0.98338244 0.01353374]\n",
+ "84: label index=1.0, class distribution=[0.02353491 0.65433551 0.32212958]\n",
+ "85: label index=1.0, class distribution=[0.01727259 0.943168 0.03955941]\n",
+ "86: label index=1.0, class distribution=[0.06513736 0.90310001 0.03176263]\n",
+ "87: label index=1.0, class distribution=[0.00545355 0.97466198 0.01988447]\n",
+ "88: label index=1.0, class distribution=[0.00228744 0.97269152 0.02502104]\n",
+ "89: label index=1.0, class distribution=[0.00732671 0.98195521 0.01071808]\n",
+ "90: label index=1.0, class distribution=[0.00725727 0.94287877 0.04986396]\n",
+ "91: label index=1.0, class distribution=[0.00725727 0.94287877 0.04986396]\n",
+ "92: label index=1.0, class distribution=[0.00732671 0.98195521 0.01071808]\n",
+ "93: label index=1.0, class distribution=[0.00308382 0.98338244 0.01353374]\n",
+ "94: label index=1.0, class distribution=[0.00725727 0.94287877 0.04986396]\n",
+ "95: label index=1.0, class distribution=[0.00308382 0.98338244 0.01353374]\n",
+ "96: label index=1.0, class distribution=[0.00732671 0.98195521 0.01071808]\n",
+ "97: label index=1.0, class distribution=[0.00308382 0.98338244 0.01353374]\n",
+ "98: label index=1.0, class distribution=[0.00228744 0.97269152 0.02502104]\n",
+ "99: label index=1.0, class distribution=[0.00725727 0.94287877 0.04986396]\n",
+ "100: label index=1.0, class distribution=[0.00308382 0.98338244 0.01353374]\n",
+ "101: label index=2.0, class distribution=[0.00102485 0.02817698 0.97079816]\n",
+ "102: label index=2.0, class distribution=[0.01274667 0.02829538 0.95895795]\n",
+ "103: label index=2.0, class distribution=[0.00102485 0.02817698 0.97079816]\n",
+ "104: label index=2.0, class distribution=[0.00139749 0.01280739 0.98579512]\n",
+ "105: label index=2.0, class distribution=[0.00102485 0.02817698 0.97079816]\n",
+ "106: label index=2.0, class distribution=[0.00102485 0.02817698 0.97079816]\n",
+ "107: label index=1.0, class distribution=[0.00725727 0.94287877 0.04986396]\n",
+ "108: label index=2.0, class distribution=[0.00139749 0.01280739 0.98579512]\n",
+ "109: label index=2.0, class distribution=[0.00139749 0.01280739 0.98579512]\n",
+ "110: label index=2.0, class distribution=[0.00431289 0.0395258 0.95616131]\n",
+ "111: label index=2.0, class distribution=[0.00102485 0.02817698 0.97079816]\n",
+ "112: label index=2.0, class distribution=[0.00139749 0.01280739 0.98579512]\n",
+ "113: label index=2.0, class distribution=[0.00102485 0.02817698 0.97079816]\n",
+ "114: label index=2.0, class distribution=[0.01274667 0.02829538 0.95895795]\n",
+ "115: label index=2.0, class distribution=[0.01274667 0.02829538 0.95895795]\n",
+ "116: label index=2.0, class distribution=[0.00102485 0.02817698 0.97079816]\n",
+ "117: label index=2.0, class distribution=[0.00102485 0.02817698 0.97079816]\n",
+ "118: label index=2.0, class distribution=[0.00431289 0.0395258 0.95616131]\n",
+ "119: label index=2.0, class distribution=[0.00139749 0.01280739 0.98579512]\n",
+ "120: label index=1.0, class distribution=[0.02353491 0.65433551 0.32212958]\n",
+ "121: label index=2.0, class distribution=[0.00102485 0.02817698 0.97079816]\n",
+ "122: label index=2.0, class distribution=[0.01274667 0.02829538 0.95895795]\n",
+ "123: label index=2.0, class distribution=[0.00139749 0.01280739 0.98579512]\n",
+ "124: label index=2.0, class distribution=[0.00139749 0.01280739 0.98579512]\n",
+ "125: label index=2.0, class distribution=[0.00102485 0.02817698 0.97079816]\n",
+ "126: label index=2.0, class distribution=[0.00102485 0.02817698 0.97079816]\n",
+ "127: label index=2.0, class distribution=[0.00139749 0.01280739 0.98579512]\n",
+ "128: label index=2.0, class distribution=[0.00920087 0.06127297 0.92952615]\n",
+ "129: label index=2.0, class distribution=[0.00139749 0.01280739 0.98579512]\n",
+ "130: label index=1.0, class distribution=[0.010867 0.52425197 0.46488102]\n",
+ "131: label index=2.0, class distribution=[0.00139749 0.01280739 0.98579512]\n",
+ "132: label index=2.0, class distribution=[0.00431289 0.0395258 0.95616131]\n",
+ "133: label index=2.0, class distribution=[0.00139749 0.01280739 0.98579512]\n",
+ "134: label index=2.0, class distribution=[0.00409632 0.47019227 0.5257114 ]\n",
+ "135: label index=1.0, class distribution=[0.02353491 0.65433551 0.32212958]\n",
+ "136: label index=2.0, class distribution=[0.00102485 0.02817698 0.97079816]\n",
+ "137: label index=2.0, class distribution=[0.00431289 0.0395258 0.95616131]\n",
+ "138: label index=2.0, class distribution=[0.00102485 0.02817698 0.97079816]\n",
+ "139: label index=2.0, class distribution=[0.00920087 0.06127297 0.92952615]\n",
+ "140: label index=2.0, class distribution=[0.00102485 0.02817698 0.97079816]\n",
+ "141: label index=2.0, class distribution=[0.00102485 0.02817698 0.97079816]\n",
+ "142: label index=2.0, class distribution=[0.00102485 0.02817698 0.97079816]\n",
+ "143: label index=2.0, class distribution=[0.01274667 0.02829538 0.95895795]\n",
+ "144: label index=2.0, class distribution=[0.00102485 0.02817698 0.97079816]\n",
+ "145: label index=2.0, class distribution=[0.00102485 0.02817698 0.97079816]\n",
+ "146: label index=2.0, class distribution=[0.00102485 0.02817698 0.97079816]\n",
+ "147: label index=2.0, class distribution=[0.00139749 0.01280739 0.98579512]\n",
+ "148: label index=2.0, class distribution=[0.00102485 0.02817698 0.97079816]\n",
+ "149: label index=2.0, class distribution=[0.00431289 0.0395258 0.95616131]\n",
+ "150: label index=2.0, class distribution=[0.00920087 0.06127297 0.92952615]\n"
+ ]
+ }
+ ],
+ "source": [
+ "from weka.classifiers import Classifier\n",
+ "cls = Classifier(classname=\"weka.classifiers.bayes.BayesNet\", options=[\"-Q\", \"weka.classifiers.bayes.net.search.local.TAN\"])\n",
+ "cls.build_classifier(data)\n",
+ "\n",
+ "for index, inst in enumerate(data):\n",
+ " pred = cls.classify_instance(inst)\n",
+ " dist = cls.distribution_for_instance(inst)\n",
+ " print(str(index+1) + \": label index=\" + str(pred) + \", class distribution=\" + str(dist))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "0b74f00a-15b3-4177-bb8c-e02ed1a3fd38",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Bayes Network Classifier\n",
+ "Using ADTree\n",
+ "#attributes=5 #classindex=4\n",
+ "Network structure (nodes followed by parents)\n",
+ "sepallength(3): class \n",
+ "sepalwidth(3): class petalwidth \n",
+ "petallength(3): class sepallength \n",
+ "petalwidth(3): class petallength \n",
+ "class(3): \n",
+ "LogScore Bayes: -484.0749140715054\n",
+ "LogScore BDeu: -653.8524681760015\n",
+ "LogScore MDL: -654.6252712234647\n",
+ "LogScore ENTROPY: -499.2955771064808\n",
+ "LogScore AIC: -561.2955771064808\n",
+ "\n"
+ ]
+ },
+ {
+ "ename": "OSError",
+ "evalue": "[Errno 63] File name too long: '\\n\\n\\n\\t \\n\\t\\n\\t\\n\\t\\n\\t \\n\\t\\n\\t\\n\\t\\n\\t\\n\\t\\n\\t\\n]>\\n\\n\\n\\n\\niris-weka.filters.supervised.attribute.Discretize-Rfirst-last-precision6-weka.filters.unsupervised.attribute.ReplaceMissingValues\\n\\nsepallength\\n'\\\\'(-inf-5.55]\\\\''\\n'\\\\'(5.55-6.15]\\\\''\\n'\\\\'(6.15-inf)\\\\''\\n\\n\\nsepalwidth\\n'\\\\'(-inf-2.95]\\\\''\\n'\\\\'(2.95-3.35]\\\\''\\n'\\\\'(3.35-inf)\\\\''\\n\\n\\npetallength\\n'\\\\'(-inf-2.45]\\\\''\\n'\\\\'(2.45-4.75]\\\\''\\n'\\\\'(4.75-inf)\\\\''\\n\\n\\npetalwidth\\n'\\\\'(-inf-0.8]\\\\''\\n'\\\\'(0.8-1.75]\\\\''\\n'\\\\'(1.75-inf)\\\\''\\n\\n\\nclass\\nIris-setosa\\nIris-versicolor\\nIris-virginica\\n\\n\\nsepallength\\nclass\\n\\n0.9223300970873787 0.06796116504854369 0.009708737864077669 \\n0.22330097087378642 0.4563106796116505 0.32038834951456313 \\n0.02912621359223301 0.20388349514563106 0.7669902912621359 \\n
\\n\\n\\nsepalwidth\\nclass\\npetalwidth\\n\\n0.04854368932038835 0.3592233009708738 0.5922330097087378 \\n0.3333333333333333 0.3333333333333333 0.3333333333333333 \\n0.3333333333333333 0.3333333333333333 0.3333333333333333 \\n0.3333333333333333 0.3333333333333333 0.3333333333333333 \\n0.6831683168316832 0.2871287128712871 0.0297029702970297 \\n0.2 0.6 0.2 \\n0.3333333333333333 0.3333333333333333 0.3333333333333333 \\n0.6923076923076923 0.23076923076923078 0.07692307692307693 \\n0.3763440860215054 0.5053763440860215 0.11827956989247312 \\n
\\n\\n\\npetallength\\nclass\\nsepallength\\n\\n0.979381443298969 0.010309278350515464 0.010309278350515464 \\n0.7777777777777778 0.1111111111111111 0.1111111111111111 \\n0.3333333333333333 0.3333333333333333 0.3333333333333333 \\n0.04 0.92 0.04 \\n0.02040816326530612 0.8775510204081632 0.10204081632653061 \\n0.02857142857142857 0.7142857142857143 0.2571428571428571 \\n0.2 0.6 0.2 \\n0.043478260869565216 0.043478260869565216 0.9130434782608695 \\n0.012345679012345678 0.012345679012345678 0.9753086419753086 \\n
\\n\\n\\npetalwidth\\nclass\\npetallength\\n\\n0.9805825242718447 0.009708737864077669 0.009708737864077669 \\n0.3333333333333333 0.3333333333333333 0.3333333333333333 \\n0.3333333333333333 0.3333333333333333 0.3333333333333333 \\n0.3333333333333333 0.3333333333333333 0.3333333333333333 \\n0.01098901098901099 0.978021978021978 0.01098901098901099 \\n0.06666666666666667 0.7333333333333333 0.2 \\n0.3333333333333333 0.3333333333333333 0.3333333333333333 \\n0.2 0.6 0.2 \\n0.009900990099009901 0.0891089108910891 0.900990099009901 \\n
\\n\\n\\nclass\\n\\n0.3333333333333333 0.3333333333333333 0.3333333333333333 \\n
\\n\\n\\n\\n'",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn [13], line 9\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;28mcls\u001b[39m)\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mweka\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mplot\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mgraph\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mgraph\u001b[39;00m \u001b[38;5;66;03m# NB: pygraphviz and PIL are required\u001b[39;00m\n\u001b[0;32m----> 9\u001b[0m \u001b[43mgraph\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mplot_dot_graph\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgraph\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m~/miniconda3/envs/pyweka/lib/python3.10/site-packages/weka/plot/graph.py:49\u001b[0m, in \u001b[0;36mplot_dot_graph\u001b[0;34m(graph, filename)\u001b[0m\n\u001b[1;32m 46\u001b[0m logger\u001b[38;5;241m.\u001b[39merror(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPIL is not installed, cannot display graph plot!\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 47\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[0;32m---> 49\u001b[0m agraph \u001b[38;5;241m=\u001b[39m \u001b[43mAGraph\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgraph\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 50\u001b[0m agraph\u001b[38;5;241m.\u001b[39mlayout(prog\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdot\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 51\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m filename \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
+ "File \u001b[0;32m~/miniconda3/envs/pyweka/lib/python3.10/site-packages/pygraphviz/agraph.py:157\u001b[0m, in \u001b[0;36mAGraph.__init__\u001b[0;34m(self, thing, filename, data, string, handle, name, strict, directed, **attr)\u001b[0m\n\u001b[1;32m 154\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_owns_handle \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[1;32m 155\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m filename \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 156\u001b[0m \u001b[38;5;66;03m# load new graph from file (creates self.handle)\u001b[39;00m\n\u001b[0;32m--> 157\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mread\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfilename\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 158\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m string \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 159\u001b[0m \u001b[38;5;66;03m# load new graph from string (creates self.handle)\u001b[39;00m\n\u001b[1;32m 160\u001b[0m \u001b[38;5;66;03m# get the charset from the string to properly encode it for\u001b[39;00m\n\u001b[1;32m 161\u001b[0m \u001b[38;5;66;03m# writing to the temporary file in from_string()\u001b[39;00m\n\u001b[1;32m 162\u001b[0m match \u001b[38;5;241m=\u001b[39m re\u001b[38;5;241m.\u001b[39msearch(\u001b[38;5;124mr\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcharset\u001b[39m\u001b[38;5;124m\\\u001b[39m\u001b[38;5;124ms*=\u001b[39m\u001b[38;5;124m\\\u001b[39m\u001b[38;5;124ms*\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m([^\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m]+)\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m, string)\n",
+ "File \u001b[0;32m~/miniconda3/envs/pyweka/lib/python3.10/site-packages/pygraphviz/agraph.py:1243\u001b[0m, in \u001b[0;36mAGraph.read\u001b[0;34m(self, path)\u001b[0m\n\u001b[1;32m 1233\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mread\u001b[39m(\u001b[38;5;28mself\u001b[39m, path):\n\u001b[1;32m 1234\u001b[0m \u001b[38;5;124;03m\"\"\"Read graph from dot format file on path.\u001b[39;00m\n\u001b[1;32m 1235\u001b[0m \n\u001b[1;32m 1236\u001b[0m \u001b[38;5;124;03m path can be a file name or file handle\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1241\u001b[0m \n\u001b[1;32m 1242\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1243\u001b[0m fh \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_fh\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpath\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1244\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1245\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_close_handle()\n",
+ "File \u001b[0;32m~/miniconda3/envs/pyweka/lib/python3.10/site-packages/pygraphviz/agraph.py:1791\u001b[0m, in \u001b[0;36mAGraph._get_fh\u001b[0;34m(self, path, mode)\u001b[0m\n\u001b[1;32m 1789\u001b[0m fh \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mpopen(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbzcat \u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m path) \u001b[38;5;66;03m# probably not portable\u001b[39;00m\n\u001b[1;32m 1790\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1791\u001b[0m fh \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmode\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1792\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(path, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwrite\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 1793\u001b[0m \u001b[38;5;66;03m# Note, mode of file handle is unchanged.\u001b[39;00m\n\u001b[1;32m 1794\u001b[0m fh \u001b[38;5;241m=\u001b[39m path\n",
+ "\u001b[0;31mOSError\u001b[0m: [Errno 63] File name too long: '\\n\\n\\n\\t \\n\\t\\n\\t\\n\\t\\n\\t \\n\\t\\n\\t\\n\\t\\n\\t\\n\\t\\n\\t\\n]>\\n\\n\\n\\n\\niris-weka.filters.supervised.attribute.Discretize-Rfirst-last-precision6-weka.filters.unsupervised.attribute.ReplaceMissingValues\\n\\nsepallength\\n'\\\\'(-inf-5.55]\\\\''\\n'\\\\'(5.55-6.15]\\\\''\\n'\\\\'(6.15-inf)\\\\''\\n\\n\\nsepalwidth\\n'\\\\'(-inf-2.95]\\\\''\\n'\\\\'(2.95-3.35]\\\\''\\n'\\\\'(3.35-inf)\\\\''\\n\\n\\npetallength\\n'\\\\'(-inf-2.45]\\\\''\\n'\\\\'(2.45-4.75]\\\\''\\n'\\\\'(4.75-inf)\\\\''\\n\\n\\npetalwidth\\n'\\\\'(-inf-0.8]\\\\''\\n'\\\\'(0.8-1.75]\\\\''\\n'\\\\'(1.75-inf)\\\\''\\n\\n\\nclass\\nIris-setosa\\nIris-versicolor\\nIris-virginica\\n\\n\\nsepallength\\nclass\\n\\n0.9223300970873787 0.06796116504854369 0.009708737864077669 \\n0.22330097087378642 0.4563106796116505 0.32038834951456313 \\n0.02912621359223301 0.20388349514563106 0.7669902912621359 \\n
\\n\\n\\nsepalwidth\\nclass\\npetalwidth\\n\\n0.04854368932038835 0.3592233009708738 0.5922330097087378 \\n0.3333333333333333 0.3333333333333333 0.3333333333333333 \\n0.3333333333333333 0.3333333333333333 0.3333333333333333 \\n0.3333333333333333 0.3333333333333333 0.3333333333333333 \\n0.6831683168316832 0.2871287128712871 0.0297029702970297 \\n0.2 0.6 0.2 \\n0.3333333333333333 0.3333333333333333 0.3333333333333333 \\n0.6923076923076923 0.23076923076923078 0.07692307692307693 \\n0.3763440860215054 0.5053763440860215 0.11827956989247312 \\n
\\n\\n\\npetallength\\nclass\\nsepallength\\n\\n0.979381443298969 0.010309278350515464 0.010309278350515464 \\n0.7777777777777778 0.1111111111111111 0.1111111111111111 \\n0.3333333333333333 0.3333333333333333 0.3333333333333333 \\n0.04 0.92 0.04 \\n0.02040816326530612 0.8775510204081632 0.10204081632653061 \\n0.02857142857142857 0.7142857142857143 0.2571428571428571 \\n0.2 0.6 0.2 \\n0.043478260869565216 0.043478260869565216 0.9130434782608695 \\n0.012345679012345678 0.012345679012345678 0.9753086419753086 \\n
\\n\\n\\npetalwidth\\nclass\\npetallength\\n\\n0.9805825242718447 0.009708737864077669 0.009708737864077669 \\n0.3333333333333333 0.3333333333333333 0.3333333333333333 \\n0.3333333333333333 0.3333333333333333 0.3333333333333333 \\n0.3333333333333333 0.3333333333333333 0.3333333333333333 \\n0.01098901098901099 0.978021978021978 0.01098901098901099 \\n0.06666666666666667 0.7333333333333333 0.2 \\n0.3333333333333333 0.3333333333333333 0.3333333333333333 \\n0.2 0.6 0.2 \\n0.009900990099009901 0.0891089108910891 0.900990099009901 \\n
\\n\\n\\nclass\\n\\n0.3333333333333333 0.3333333333333333 0.3333333333333333 \\n
\\n\\n\\n\\n'"
+ ]
+ }
+ ],
+ "source": [
+ "from weka.classifiers import Classifier\n",
+ "\n",
+ "cls = Classifier(classname=\"weka.classifiers.bayes.BayesNet\", options=[\"-Q\", \"weka.classifiers.bayes.net.search.local.TAN\"])\n",
+ "cls.build_classifier(data)\n",
+ "\n",
+ "print(cls)\n",
+ "\n",
+ "import weka.plot.graph as graph # NB: pygraphviz and PIL are required\n",
+ "graph.plot_dot_graph(cls.graph)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3f59f200-4f23-4add-86ae-6df1494ede82",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/benchmark/Experiments.py b/benchmark/Experiments.py
index 658805a..6e4b763 100644
--- a/benchmark/Experiments.py
+++ b/benchmark/Experiments.py
@@ -1,4 +1,5 @@
import os
+import sys
import json
import random
import warnings
@@ -162,6 +163,10 @@ class Experiment:
def get_output_file(self):
return self.output_file
+ @staticmethod
+ def get_python_version():
+ return "{}.{}".format(sys.version_info.major, sys.version_info.minor)
+
def _build_classifier(self, random_state, hyperparameters):
self.model = Models.get_model(self.model_name, random_state)
clf = self.model
@@ -193,7 +198,7 @@ class Experiment:
shuffle=True, random_state=random_state, n_splits=self.folds
)
clf = self._build_classifier(random_state, hyperparameters)
- self.version = clf.version() if hasattr(clf, "version") else "-"
+ self.version = Models.get_version(self.model_name, clf)
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
res = cross_validate(
@@ -243,6 +248,8 @@ class Experiment:
output["duration"] = self.duration
output["seeds"] = self.random_seeds
output["platform"] = self.platform
+ output["language_version"] = self.get_python_version()
+ output["language"] = "Python"
output["results"] = self.results
with open(self.output_file, "w") as f:
json.dump(output, f)
diff --git a/benchmark/Models.py b/benchmark/Models.py
index 9b363c3..03d31eb 100644
--- a/benchmark/Models.py
+++ b/benchmark/Models.py
@@ -11,6 +11,8 @@ from stree import Stree
from wodt import Wodt
from odte import Odte
from xgboost import XGBClassifier
+import sklearn
+import xgboost
class Models:
@@ -89,3 +91,15 @@ class Models:
nodes, leaves = result.nodes_leaves()
depth = result.depth_ if hasattr(result, "depth_") else 0
return nodes, leaves, depth
+
+ @staticmethod
+ def get_version(name, clf):
+ if hasattr(clf, "version"):
+ return clf.version()
+ if name in ["Cart", "ExtraTree", "RandomForest", "GBC", "SVC"]:
+ return sklearn.__version__
+ elif name.startswith("Bagging") or name.startswith("AdaBoost"):
+ return sklearn.__version__
+ elif name == "XGBoost":
+ return xgboost.__version__
+ return "Error"
diff --git a/benchmark/Results.py b/benchmark/Results.py
index bf0c8e6..b76434a 100644
--- a/benchmark/Results.py
+++ b/benchmark/Results.py
@@ -16,6 +16,7 @@ from .Utils import (
Symbols,
TextColor,
NO_RESULTS,
+ PYTHON_VERSION,
)
@@ -196,7 +197,8 @@ class Report(BaseReport):
self._compare_totals = {}
self.header_line("*")
self.header_line(
- f" Report {self.data['model']} ver. {self.data['version']}"
+ f" {self.data['model']} ver. {self.data['version']}"
+ f" {self.data['language']} ver. {self.data['language_version']}"
f" with {self.data['folds']} Folds "
f"cross validation and {len(self.data['seeds'])} random seeds. "
f"{self.data['date']} {self.data['time']}"
@@ -347,7 +349,8 @@ class Excel(BaseReport):
def get_title(self):
return (
- f" Report {self.data['model']} ver. {self.data['version']}"
+ f" {self.data['model']} ver. {self.data['version']}"
+ f" {self.data['language']} ver. {self.data['language_version']}"
f" with {self.data['folds']} Folds "
f"cross validation and {len(self.data['seeds'])} random seeds. "
f"{self.data['date']} {self.data['time']}"
diff --git a/benchmark/Utils.py b/benchmark/Utils.py
index d470959..b6b5797 100644
--- a/benchmark/Utils.py
+++ b/benchmark/Utils.py
@@ -1,6 +1,8 @@
import os
+import sys
import subprocess
+PYTHON_VERSION = "{}.{}".format(sys.version_info.major, sys.version_info.minor)
NO_RESULTS = "** No results found **"
NO_ENV = "File .env not found"
diff --git a/benchmark/tests/Excel_test.py b/benchmark/tests/Excel_test.py
index 6ed24cf..226b85b 100644
--- a/benchmark/tests/Excel_test.py
+++ b/benchmark/tests/Excel_test.py
@@ -4,6 +4,7 @@ from xlsxwriter import Workbook
from .TestBase import TestBase
from ..Results import Excel
from ..Utils import Folders
+import benchmark.Utils
class ExcelTest(TestBase):
diff --git a/benchmark/tests/Models_test.py b/benchmark/tests/Models_test.py
index ea5b12a..911cc95 100644
--- a/benchmark/tests/Models_test.py
+++ b/benchmark/tests/Models_test.py
@@ -15,6 +15,8 @@ from odte import Odte
from xgboost import XGBClassifier
from .TestBase import TestBase
from ..Models import Models
+import xgboost
+import sklearn
class ModelTest(TestBase):
@@ -33,6 +35,38 @@ class ModelTest(TestBase):
for key, value in test.items():
self.assertIsInstance(Models.get_model(key), value)
+ def test_Models_version(self):
+ def ver_stree():
+ return "1.2.3"
+
+ def ver_wodt():
+ return "h.j.k"
+
+ def ver_odte():
+ return "4.5.6"
+
+ test = {
+ "STree": [ver_stree, "1.2.3"],
+ "Wodt": [ver_wodt, "h.j.k"],
+ "ODTE": [ver_odte, "4.5.6"],
+ "RandomForest": [None, "7.8.9"],
+ "BaggingStree": [None, "x.y.z"],
+ "AdaBoostStree": [None, "w.x.z"],
+ "XGBoost": [None, "10.11.12"],
+ }
+ for key, value in test.items():
+ clf = Models.get_model(key)
+ if key in ["STree", "Wodt", "ODTE"]:
+ clf.version = value[0]
+ elif key == "XGBoost":
+ xgboost.__version__ = value[1]
+ else:
+ sklearn.__version__ = value[1]
+ self.assertEqual(Models.get_version(key, clf), value[1])
+
+ def test_bogus_Model_Version(self):
+ self.assertEqual(Models.get_version("unknown", None), "Error")
+
def test_BaggingStree(self):
clf = Models.get_model("BaggingStree")
self.assertIsInstance(clf, BaggingClassifier)
diff --git a/benchmark/tests/results/results_accuracy_ODTE_Galgo_2022-04-20_10:52:20_0.json b/benchmark/tests/results/results_accuracy_ODTE_Galgo_2022-04-20_10:52:20_0.json
index b2485f7..34063a0 100644
--- a/benchmark/tests/results/results_accuracy_ODTE_Galgo_2022-04-20_10:52:20_0.json
+++ b/benchmark/tests/results/results_accuracy_ODTE_Galgo_2022-04-20_10:52:20_0.json
@@ -3,6 +3,8 @@
"title": "Gridsearched hyperparams v022.1b random_init",
"model": "ODTE",
"version": "0.3.2",
+ "language_version": "3.11x",
+ "language": "Python",
"stratified": false,
"folds": 5,
"date": "2022-04-20",
diff --git a/benchmark/tests/results/results_accuracy_RandomForest_iMac27_2022-01-14_12:39:30_0.json b/benchmark/tests/results/results_accuracy_RandomForest_iMac27_2022-01-14_12:39:30_0.json
index ef522ed..e4e5490 100644
--- a/benchmark/tests/results/results_accuracy_RandomForest_iMac27_2022-01-14_12:39:30_0.json
+++ b/benchmark/tests/results/results_accuracy_RandomForest_iMac27_2022-01-14_12:39:30_0.json
@@ -3,6 +3,8 @@
"title": "Test default paramters with RandomForest",
"model": "RandomForest",
"version": "-",
+ "language_version": "3.11x",
+ "language": "Python",
"stratified": false,
"folds": 5,
"date": "2022-01-14",
diff --git a/benchmark/tests/results/results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json b/benchmark/tests/results/results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json
index 6704808..197b1f6 100644
--- a/benchmark/tests/results/results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json
+++ b/benchmark/tests/results/results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json
@@ -3,6 +3,8 @@
"model": "STree",
"stratified": false,
"folds": 5,
+ "language_version": "3.11x",
+ "language": "Python",
"date": "2021-09-30",
"time": "11:42:07",
"duration": 624.2505249977112,
diff --git a/benchmark/tests/results/results_accuracy_STree_iMac27_2021-10-27_09:40:40_0.json b/benchmark/tests/results/results_accuracy_STree_iMac27_2021-10-27_09:40:40_0.json
index e5fd58e..ee64f53 100644
--- a/benchmark/tests/results/results_accuracy_STree_iMac27_2021-10-27_09:40:40_0.json
+++ b/benchmark/tests/results/results_accuracy_STree_iMac27_2021-10-27_09:40:40_0.json
@@ -1,6 +1,8 @@
{
"score_name": "accuracy",
"model": "STree",
+ "language": "Python",
+ "language_version": "3.11x",
"stratified": false,
"folds": 5,
"date": "2021-10-27",
diff --git a/benchmark/tests/results/results_accuracy_STree_macbook-pro_2021-11-01_19:17:07_0.json b/benchmark/tests/results/results_accuracy_STree_macbook-pro_2021-11-01_19:17:07_0.json
index 1ca939e..3703250 100644
--- a/benchmark/tests/results/results_accuracy_STree_macbook-pro_2021-11-01_19:17:07_0.json
+++ b/benchmark/tests/results/results_accuracy_STree_macbook-pro_2021-11-01_19:17:07_0.json
@@ -1,6 +1,8 @@
{
"score_name": "accuracy",
"model": "STree",
+ "language_version": "3.11x",
+ "language": "Python",
"stratified": false,
"folds": 5,
"date": "2021-11-01",
diff --git a/benchmark/tests/test_files/be_main_best.test b/benchmark/tests/test_files/be_main_best.test
index 8b21255..adecb4c 100644
--- a/benchmark/tests/test_files/be_main_best.test
+++ b/benchmark/tests/test_files/be_main_best.test
@@ -1,5 +1,5 @@
[94m************************************************************************************************************************
-[94m* Report STree ver. 1.2.4 with 5 Folds cross validation and 10 random seeds. 2022-05-09 00:15:25 *
+[94m* STree ver. 1.2.4 Python ver. 3.11x with 5 Folds cross validation and 10 random seeds. 2022-05-09 00:15:25 *
[94m* test *
[94m* Random seeds: [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] Stratified: False *
[94m* Execution took 0.80 seconds, 0.00 hours, on iMac27 *
diff --git a/benchmark/tests/test_files/be_main_complete.test b/benchmark/tests/test_files/be_main_complete.test
index 793d267..8f36fe6 100644
--- a/benchmark/tests/test_files/be_main_complete.test
+++ b/benchmark/tests/test_files/be_main_complete.test
@@ -1,5 +1,5 @@
[94m************************************************************************************************************************
-[94m* Report STree ver. 1.2.4 with 5 Folds cross validation and 10 random seeds. 2022-05-08 20:14:43 *
+[94m* STree ver. 1.2.4 Python ver. 3.11x with 5 Folds cross validation and 10 random seeds. 2022-05-08 20:14:43 *
[94m* test *
[94m* Random seeds: [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] Stratified: False *
[94m* Execution took 0.48 seconds, 0.00 hours, on iMac27 *
diff --git a/benchmark/tests/test_files/be_main_dataset.test b/benchmark/tests/test_files/be_main_dataset.test
index abfcc76..46e50e3 100644
--- a/benchmark/tests/test_files/be_main_dataset.test
+++ b/benchmark/tests/test_files/be_main_dataset.test
@@ -1,5 +1,5 @@
[94m************************************************************************************************************************
-[94m* Report STree ver. 1.2.4 with 5 Folds cross validation and 10 random seeds. 2022-05-08 19:38:28 *
+[94m* STree ver. 1.2.4 Python ver. 3.11x with 5 Folds cross validation and 10 random seeds. 2022-05-08 19:38:28 *
[94m* test *
[94m* Random seeds: [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] Stratified: False *
[94m* Execution took 0.06 seconds, 0.00 hours, on iMac27 *
diff --git a/benchmark/tests/test_files/be_main_grid.test b/benchmark/tests/test_files/be_main_grid.test
index a4bec6e..c29ec32 100644
--- a/benchmark/tests/test_files/be_main_grid.test
+++ b/benchmark/tests/test_files/be_main_grid.test
@@ -1,5 +1,5 @@
[94m************************************************************************************************************************
-[94m* Report STree ver. 1.2.4 with 5 Folds cross validation and 10 random seeds. 2022-05-09 00:21:06 *
+[94m* STree ver. 1.2.4 Python ver. 3.11x with 5 Folds cross validation and 10 random seeds. 2022-05-09 00:21:06 *
[94m* test *
[94m* Random seeds: [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] Stratified: False *
[94m* Execution took 0.89 seconds, 0.00 hours, on iMac27 *
diff --git a/benchmark/tests/test_files/excel.test b/benchmark/tests/test_files/excel.test
index 373c803..dbb9675 100644
--- a/benchmark/tests/test_files/excel.test
+++ b/benchmark/tests/test_files/excel.test
@@ -1,4 +1,4 @@
-1;1;" Report STree ver. 1.2.3 with 5 Folds cross validation and 10 random seeds. 2021-09-30 11:42:07"
+1;1;" STree ver. 1.2.3 Python ver. 3.11x with 5 Folds cross validation and 10 random seeds. 2021-09-30 11:42:07"
2;1;" With gridsearched hyperparameters"
3;1;" Score is accuracy"
3;2;" Execution time"
diff --git a/benchmark/tests/test_files/excel_add_ODTE.test b/benchmark/tests/test_files/excel_add_ODTE.test
index 34f226f..55ac7bb 100644
--- a/benchmark/tests/test_files/excel_add_ODTE.test
+++ b/benchmark/tests/test_files/excel_add_ODTE.test
@@ -1,4 +1,4 @@
-1;1;" Report ODTE ver. 0.3.2 with 5 Folds cross validation and 10 random seeds. 2022-04-20 10:52:20"
+1;1;" ODTE ver. 0.3.2 Python ver. 3.11x with 5 Folds cross validation and 10 random seeds. 2022-04-20 10:52:20"
2;1;" Gridsearched hyperparams v022.1b random_init"
3;1;" Score is accuracy"
3;2;" Execution time"
diff --git a/benchmark/tests/test_files/excel_add_STree.test b/benchmark/tests/test_files/excel_add_STree.test
index 3a864e4..1e105f9 100644
--- a/benchmark/tests/test_files/excel_add_STree.test
+++ b/benchmark/tests/test_files/excel_add_STree.test
@@ -1,4 +1,4 @@
-1;1;" Report STree ver. 1.2.3 with 5 Folds cross validation and 10 random seeds. 2021-10-27 09:40:40"
+1;1;" STree ver. 1.2.3 Python ver. 3.11x with 5 Folds cross validation and 10 random seeds. 2021-10-27 09:40:40"
2;1;" default A"
3;1;" Score is accuracy"
3;2;" Execution time"
diff --git a/benchmark/tests/test_files/excel_compared.test b/benchmark/tests/test_files/excel_compared.test
index 16b415a..2b44c1e 100644
--- a/benchmark/tests/test_files/excel_compared.test
+++ b/benchmark/tests/test_files/excel_compared.test
@@ -1,4 +1,4 @@
-1;1;" Report STree ver. 1.2.3 with 5 Folds cross validation and 10 random seeds. 2021-09-30 11:42:07"
+1;1;" STree ver. 1.2.3 Python ver. 3.11x with 5 Folds cross validation and 10 random seeds. 2021-09-30 11:42:07"
2;1;" With gridsearched hyperparameters"
3;1;" Score is accuracy"
3;2;" Execution time"
diff --git a/benchmark/tests/test_files/exreport_excel_ODTE.test b/benchmark/tests/test_files/exreport_excel_ODTE.test
index 46188f4..f1ea096 100644
--- a/benchmark/tests/test_files/exreport_excel_ODTE.test
+++ b/benchmark/tests/test_files/exreport_excel_ODTE.test
@@ -1,4 +1,4 @@
-1;1;" Report ODTE ver. 0.3.2 with 5 Folds cross validation and 10 random seeds. 2022-04-20 10:52:20"
+1;1;" ODTE ver. 0.3.2 Python ver. 3.11x with 5 Folds cross validation and 10 random seeds. 2022-04-20 10:52:20"
2;1;" Gridsearched hyperparams v022.1b random_init"
3;1;" Score is accuracy"
3;2;" Execution time"
diff --git a/benchmark/tests/test_files/exreport_excel_RandomForest.test b/benchmark/tests/test_files/exreport_excel_RandomForest.test
index 7e7a395..aa608d0 100644
--- a/benchmark/tests/test_files/exreport_excel_RandomForest.test
+++ b/benchmark/tests/test_files/exreport_excel_RandomForest.test
@@ -1,4 +1,4 @@
-1;1;" Report RandomForest ver. - with 5 Folds cross validation and 10 random seeds. 2022-01-14 12:39:30"
+1;1;" RandomForest ver. - Python ver. 3.11x with 5 Folds cross validation and 10 random seeds. 2022-01-14 12:39:30"
2;1;" Test default paramters with RandomForest"
3;1;" Score is accuracy"
3;2;" Execution time"
diff --git a/benchmark/tests/test_files/exreport_excel_STree.test b/benchmark/tests/test_files/exreport_excel_STree.test
index 18b7aa4..6a164b5 100644
--- a/benchmark/tests/test_files/exreport_excel_STree.test
+++ b/benchmark/tests/test_files/exreport_excel_STree.test
@@ -1,4 +1,4 @@
-1;1;" Report STree ver. 1.2.3 with 5 Folds cross validation and 10 random seeds. 2021-09-30 11:42:07"
+1;1;" STree ver. 1.2.3 Python ver. 3.11x with 5 Folds cross validation and 10 random seeds. 2021-09-30 11:42:07"
2;1;" With gridsearched hyperparameters"
3;1;" Score is accuracy"
3;2;" Execution time"
diff --git a/benchmark/tests/test_files/report.test b/benchmark/tests/test_files/report.test
index 94498b7..47d4a1a 100644
--- a/benchmark/tests/test_files/report.test
+++ b/benchmark/tests/test_files/report.test
@@ -1,5 +1,5 @@
[94m************************************************************************************************************************
-[94m* Report STree ver. 1.2.3 with 5 Folds cross validation and 10 random seeds. 2021-09-30 11:42:07 *
+[94m* STree ver. 1.2.3 Python ver. 3.11x with 5 Folds cross validation and 10 random seeds. 2021-09-30 11:42:07 *
[94m* With gridsearched hyperparameters *
[94m* Random seeds: [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] Stratified: False *
[94m* Execution took 624.25 seconds, 0.17 hours, on iMac27 *
diff --git a/benchmark/tests/test_files/report_compared.test b/benchmark/tests/test_files/report_compared.test
index 46c6f6c..018e3ee 100644
--- a/benchmark/tests/test_files/report_compared.test
+++ b/benchmark/tests/test_files/report_compared.test
@@ -1,5 +1,5 @@
[94m************************************************************************************************************************
-[94m* Report STree ver. 1.2.3 with 5 Folds cross validation and 10 random seeds. 2021-09-30 11:42:07 *
+[94m* STree ver. 1.2.3 Python ver. 3.11x with 5 Folds cross validation and 10 random seeds. 2021-09-30 11:42:07 *
[94m* With gridsearched hyperparameters *
[94m* Random seeds: [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] Stratified: False *
[94m* Execution took 624.25 seconds, 0.17 hours, on iMac27 *