diff --git a/benchmark/Datasets.py b/benchmark/Datasets.py index 20a4894..0150623 100644 --- a/benchmark/Datasets.py +++ b/benchmark/Datasets.py @@ -24,13 +24,12 @@ class DatasetsArff: def folder(): return "datasets" - def load(self, name, class_name="class"): + def load(self, name, class_name): file_name = os.path.join(self.folder(), self.dataset_names(name)) data = arff.loadarff(file_name) df = pd.DataFrame(data[0]) - y = df[class_name] - X = data.drop(class_name, axis=1).to_numpy() - y = data[class_name].to_numpy() + X = df.drop(class_name, axis=1).to_numpy() + y = df[class_name].to_numpy() return X, y @@ -43,7 +42,7 @@ class DatasetsTanveer: def folder(): return "data" - def load(self, name): + def load(self, name, _): file_name = os.path.join(self.folder(), self.dataset_names(name)) data = pd.read_csv( file_name, @@ -64,7 +63,7 @@ class DatasetsSurcov: def folder(): return "datasets" - def load(self, name): + def load(self, name, _): file_name = os.path.join(self.folder(), self.dataset_names(name)) data = pd.read_csv( file_name, @@ -80,23 +79,38 @@ class DatasetsSurcov: class Datasets: def __init__(self, dataset_name=None): + default_class = "class" envData = EnvData.load() class_name = getattr( __import__(__name__), f"Datasets{envData['source_data']}", ) self.dataset = class_name() + self.class_names = [] if dataset_name is None: file_name = os.path.join(self.dataset.folder(), Files.index) with open(file_name) as f: self.data_sets = f.read().splitlines() + self.class_names = [default_class] * len(self.data_sets) + if "," in self.data_sets[0]: + result = [] + class_names = [] + for data in self.data_sets: + name, class_name = data.split(",") + result.append(name) + class_names.append(class_name) + self.data_sets = result + self.class_names = class_names + else: self.data_sets = [dataset_name] + self.class_names = [default_class] def load(self, name): try: - return self.dataset.load(name) - except FileNotFoundError: + class_name = self.class_names[self.data_sets.index(name)] + return self.dataset.load(name, class_name) + except (ValueError, FileNotFoundError): raise ValueError(f"Unknown dataset: {name}") def __iter__(self) -> Diterator: diff --git a/benchmark/__init__.py b/benchmark/__init__.py index cac5b02..e26a7cf 100644 --- a/benchmark/__init__.py +++ b/benchmark/__init__.py @@ -1,4 +1,4 @@ -from .Datasets import Datasets, DatasetsSurcov, DatasetsTanveer +from .Datasets import Datasets, DatasetsSurcov, DatasetsTanveer, DatasetsArff from .Experiments import Experiment from .Results import Report, Summary