Add Arff as source_data for datasets

This commit is contained in:
2022-10-24 21:04:07 +02:00
parent 7875e2e6ac
commit 34b3bd94de
2 changed files with 23 additions and 9 deletions

View File

@@ -24,13 +24,12 @@ class DatasetsArff:
def folder():
return "datasets"
def load(self, name, class_name="class"):
def load(self, name, class_name):
file_name = os.path.join(self.folder(), self.dataset_names(name))
data = arff.loadarff(file_name)
df = pd.DataFrame(data[0])
y = df[class_name]
X = data.drop(class_name, axis=1).to_numpy()
y = data[class_name].to_numpy()
X = df.drop(class_name, axis=1).to_numpy()
y = df[class_name].to_numpy()
return X, y
@@ -43,7 +42,7 @@ class DatasetsTanveer:
def folder():
return "data"
def load(self, name):
def load(self, name, _):
file_name = os.path.join(self.folder(), self.dataset_names(name))
data = pd.read_csv(
file_name,
@@ -64,7 +63,7 @@ class DatasetsSurcov:
def folder():
return "datasets"
def load(self, name):
def load(self, name, _):
file_name = os.path.join(self.folder(), self.dataset_names(name))
data = pd.read_csv(
file_name,
@@ -80,23 +79,38 @@ class DatasetsSurcov:
class Datasets:
def __init__(self, dataset_name=None):
default_class = "class"
envData = EnvData.load()
class_name = getattr(
__import__(__name__),
f"Datasets{envData['source_data']}",
)
self.dataset = class_name()
self.class_names = []
if dataset_name is None:
file_name = os.path.join(self.dataset.folder(), Files.index)
with open(file_name) as f:
self.data_sets = f.read().splitlines()
self.class_names = [default_class] * len(self.data_sets)
if "," in self.data_sets[0]:
result = []
class_names = []
for data in self.data_sets:
name, class_name = data.split(",")
result.append(name)
class_names.append(class_name)
self.data_sets = result
self.class_names = class_names
else:
self.data_sets = [dataset_name]
self.class_names = [default_class]
def load(self, name):
try:
return self.dataset.load(name)
except FileNotFoundError:
class_name = self.class_names[self.data_sets.index(name)]
return self.dataset.load(name, class_name)
except (ValueError, FileNotFoundError):
raise ValueError(f"Unknown dataset: {name}")
def __iter__(self) -> Diterator:

View File

@@ -1,4 +1,4 @@
from .Datasets import Datasets, DatasetsSurcov, DatasetsTanveer
from .Datasets import Datasets, DatasetsSurcov, DatasetsTanveer, DatasetsArff
from .Experiments import Experiment
from .Results import Report, Summary