diff --git a/benchmark/Datasets.py b/benchmark/Datasets.py index e3735e7..cf06740 100644 --- a/benchmark/Datasets.py +++ b/benchmark/Datasets.py @@ -24,14 +24,16 @@ class DatasetsArff: def folder(): return "datasets" - def load(self, name, class_name): + def load(self, name, class_name, dataframe): file_name = os.path.join(self.folder(), self.dataset_names(name)) data = arff.loadarff(file_name) df = pd.DataFrame(data[0]) df = df.dropna() - X = df.drop(class_name, axis=1).to_numpy() + X = df.drop(class_name, axis=1) + self.features = X.columns + self.class_name = class_name y, _ = pd.factorize(df[class_name]) - return X, y + return df if dataframe else (X.to_numpy(), y) class DatasetsTanveer: @@ -43,7 +45,7 @@ class DatasetsTanveer: def folder(): return "data" - def load(self, name, _): + def load(self, name, *args): file_name = os.path.join(self.folder(), self.dataset_names(name)) data = pd.read_csv( file_name, @@ -64,7 +66,7 @@ class DatasetsSurcov: def folder(): return "datasets" - def load(self, name, _): + def load(self, name, *args): file_name = os.path.join(self.folder(), self.dataset_names(name)) data = pd.read_csv( file_name, @@ -115,10 +117,10 @@ class Datasets: self.data_sets = result self.class_names = class_names - def load(self, name): + def load(self, name, dataframe=False): try: class_name = self.class_names[self.data_sets.index(name)] - return self.dataset.load(name, class_name) + return self.dataset.load(name, class_name, dataframe) except (ValueError, FileNotFoundError): raise ValueError(f"Unknown dataset: {name}")