Add Arff as source_data for datasets

This commit is contained in:
2022-10-24 21:04:07 +02:00
parent 7875e2e6ac
commit 34b3bd94de
2 changed files with 23 additions and 9 deletions

View File

@@ -24,13 +24,12 @@ class DatasetsArff:
def folder(): def folder():
return "datasets" return "datasets"
def load(self, name, class_name="class"): def load(self, name, class_name):
file_name = os.path.join(self.folder(), self.dataset_names(name)) file_name = os.path.join(self.folder(), self.dataset_names(name))
data = arff.loadarff(file_name) data = arff.loadarff(file_name)
df = pd.DataFrame(data[0]) df = pd.DataFrame(data[0])
y = df[class_name] X = df.drop(class_name, axis=1).to_numpy()
X = data.drop(class_name, axis=1).to_numpy() y = df[class_name].to_numpy()
y = data[class_name].to_numpy()
return X, y return X, y
@@ -43,7 +42,7 @@ class DatasetsTanveer:
def folder(): def folder():
return "data" return "data"
def load(self, name): def load(self, name, _):
file_name = os.path.join(self.folder(), self.dataset_names(name)) file_name = os.path.join(self.folder(), self.dataset_names(name))
data = pd.read_csv( data = pd.read_csv(
file_name, file_name,
@@ -64,7 +63,7 @@ class DatasetsSurcov:
def folder(): def folder():
return "datasets" return "datasets"
def load(self, name): def load(self, name, _):
file_name = os.path.join(self.folder(), self.dataset_names(name)) file_name = os.path.join(self.folder(), self.dataset_names(name))
data = pd.read_csv( data = pd.read_csv(
file_name, file_name,
@@ -80,23 +79,38 @@ class DatasetsSurcov:
class Datasets: class Datasets:
def __init__(self, dataset_name=None): def __init__(self, dataset_name=None):
default_class = "class"
envData = EnvData.load() envData = EnvData.load()
class_name = getattr( class_name = getattr(
__import__(__name__), __import__(__name__),
f"Datasets{envData['source_data']}", f"Datasets{envData['source_data']}",
) )
self.dataset = class_name() self.dataset = class_name()
self.class_names = []
if dataset_name is None: if dataset_name is None:
file_name = os.path.join(self.dataset.folder(), Files.index) file_name = os.path.join(self.dataset.folder(), Files.index)
with open(file_name) as f: with open(file_name) as f:
self.data_sets = f.read().splitlines() self.data_sets = f.read().splitlines()
self.class_names = [default_class] * len(self.data_sets)
if "," in self.data_sets[0]:
result = []
class_names = []
for data in self.data_sets:
name, class_name = data.split(",")
result.append(name)
class_names.append(class_name)
self.data_sets = result
self.class_names = class_names
else: else:
self.data_sets = [dataset_name] self.data_sets = [dataset_name]
self.class_names = [default_class]
def load(self, name): def load(self, name):
try: try:
return self.dataset.load(name) class_name = self.class_names[self.data_sets.index(name)]
except FileNotFoundError: return self.dataset.load(name, class_name)
except (ValueError, FileNotFoundError):
raise ValueError(f"Unknown dataset: {name}") raise ValueError(f"Unknown dataset: {name}")
def __iter__(self) -> Diterator: def __iter__(self) -> Diterator:

View File

@@ -1,4 +1,4 @@
from .Datasets import Datasets, DatasetsSurcov, DatasetsTanveer from .Datasets import Datasets, DatasetsSurcov, DatasetsTanveer, DatasetsArff
from .Experiments import Experiment from .Experiments import Experiment
from .Results import Report, Summary from .Results import Report, Summary