diff --git a/report.py b/report.py index 0800bfc..be0a9a2 100644 --- a/report.py +++ b/report.py @@ -1,4 +1,5 @@ import sys +import numpy as np from experimentation.Sets import Datasets set_name = "aaai" @@ -7,5 +8,20 @@ if len(sys.argv) > 1: if set_name != "aaai" and set_name != "tanveer": print("First parameter has to be one of: {aaai, tanveer}") exit(1) +if len(sys.argv) > 2: + csv = sys.argv[2] == "tex" +else: + csv = False + datasets = Datasets(False, False, set_name) -datasets.report() +if csv: + for number, dataset in enumerate(datasets): + X, y = datasets.load(dataset[0]) # type: ignore + samples, features = X.shape + classes = len(np.unique(y)) + print( + "%d & %s & %d & %d & %d \\\\" + % (number + 1, dataset[0], X.shape[0], X.shape[1], classes) + ) +else: + datasets.report()