From 4838e3ef8bb30086343548325de46223d502032c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Thu, 3 Mar 2022 17:34:21 +0100 Subject: [PATCH] Add --dataset to main.py --- src/Experiments.py | 17 ++++++++++------- src/Results.py | 2 +- src/main.py | 17 ++++++++++++++++- 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/src/Experiments.py b/src/Experiments.py index 8c1706b..8470771 100644 --- a/src/Experiments.py +++ b/src/Experiments.py @@ -27,13 +27,16 @@ class Diterator: class Datasets: - def __init__(self): - try: - with open(os.path.join(Folders.data, Files.index)) as f: - self.data_sets = f.read().splitlines() - except FileNotFoundError: - with open(os.path.join("..", Folders.data, Files.index)) as f: - self.data_sets = f.read().splitlines() + def __init__(self, dataset=None): + if dataset is None: + try: + with open(os.path.join(Folders.data, Files.index)) as f: + self.data_sets = f.read().splitlines() + except FileNotFoundError: + with open(os.path.join("..", Folders.data, Files.index)) as f: + self.data_sets = f.read().splitlines() + else: + self.data_sets = [dataset] def load(self, name): try: diff --git a/src/Results.py b/src/Results.py index 52564f4..1da1755 100644 --- a/src/Results.py +++ b/src/Results.py @@ -622,7 +622,7 @@ class Benchmark: column += 2 row += 1 column = 1 - for _ in range(len(self._results)): + for _ in range(len(self._models)): sheet.write(row, column, "Score", merge_format) sheet.write(row, column + 1, "Stdev", merge_format) column += 2 diff --git a/src/main.py b/src/main.py index cbf2da3..2964bcb 100755 --- a/src/main.py +++ b/src/main.py @@ -1,4 +1,5 @@ #!/usr/bin/env python +import os import argparse from Experiments import Experiment, Datasets from Results import Report @@ -78,6 +79,14 @@ def parse_arguments(): required=True, help="Stratified", ) + ap.add_argument( + "-d", + "--dataset", + type=str, + required=False, + default=None, + help="Experiment with only this dataset", + ) args = ap.parse_args() return ( args.stratified, @@ -90,6 +99,7 @@ def parse_arguments(): args.paramfile, args.report, args.title, + args.dataset, ) @@ -104,12 +114,14 @@ def parse_arguments(): paramfile, report, experiment_title, + dataset, ) = parse_arguments() +report = report or dataset is not None job = Experiment( score_name=score, model_name=model, stratified=stratified, - datasets=Datasets(), + datasets=Datasets(dataset=dataset), hyperparams_dict=hyperparameters, hyperparams_file=paramfile, progress_bar=not quiet, @@ -122,3 +134,6 @@ if report: result_file = job.get_output_file() report = Report(result_file) report.report() +if dataset is not None: + print(f"Partial result file removed: {result_file}") + os.remove(result_file)