Add --dataset to main.py

This commit is contained in:
2022-03-03 17:34:21 +01:00
parent dabbbe3fd8
commit 4838e3ef8b
3 changed files with 27 additions and 9 deletions

View File

@@ -27,13 +27,16 @@ class Diterator:
class Datasets:
def __init__(self):
try:
with open(os.path.join(Folders.data, Files.index)) as f:
self.data_sets = f.read().splitlines()
except FileNotFoundError:
with open(os.path.join("..", Folders.data, Files.index)) as f:
self.data_sets = f.read().splitlines()
def __init__(self, dataset=None):
if dataset is None:
try:
with open(os.path.join(Folders.data, Files.index)) as f:
self.data_sets = f.read().splitlines()
except FileNotFoundError:
with open(os.path.join("..", Folders.data, Files.index)) as f:
self.data_sets = f.read().splitlines()
else:
self.data_sets = [dataset]
def load(self, name):
try:

View File

@@ -622,7 +622,7 @@ class Benchmark:
column += 2
row += 1
column = 1
for _ in range(len(self._results)):
for _ in range(len(self._models)):
sheet.write(row, column, "Score", merge_format)
sheet.write(row, column + 1, "Stdev", merge_format)
column += 2

View File

@@ -1,4 +1,5 @@
#!/usr/bin/env python
import os
import argparse
from Experiments import Experiment, Datasets
from Results import Report
@@ -78,6 +79,14 @@ def parse_arguments():
required=True,
help="Stratified",
)
ap.add_argument(
"-d",
"--dataset",
type=str,
required=False,
default=None,
help="Experiment with only this dataset",
)
args = ap.parse_args()
return (
args.stratified,
@@ -90,6 +99,7 @@ def parse_arguments():
args.paramfile,
args.report,
args.title,
args.dataset,
)
@@ -104,12 +114,14 @@ def parse_arguments():
paramfile,
report,
experiment_title,
dataset,
) = parse_arguments()
report = report or dataset is not None
job = Experiment(
score_name=score,
model_name=model,
stratified=stratified,
datasets=Datasets(),
datasets=Datasets(dataset=dataset),
hyperparams_dict=hyperparameters,
hyperparams_file=paramfile,
progress_bar=not quiet,
@@ -122,3 +134,6 @@ if report:
result_file = job.get_output_file()
report = Report(result_file)
report.report()
if dataset is not None:
print(f"Partial result file removed: {result_file}")
os.remove(result_file)