Refactor Datasets

This commit is contained in:
2022-11-22 16:26:04 +01:00
parent 93f0db36fa
commit 8aa76c27c3
4 changed files with 51 additions and 47 deletions

View File

@@ -12,11 +12,8 @@ jobs:
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
strategy: strategy:
matrix: matrix:
os: [macos-latest, ubuntu-latest] os: [ubuntu-latest]
python: ["3.10", "3.11"] python: ["3.10", "3.11"]
exclude:
- os: macos-latest
python: "3.11"
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3

View File

@@ -31,6 +31,7 @@ class DatasetsArff:
data = arff.loadarff(file_name) data = arff.loadarff(file_name)
df = pd.DataFrame(data[0]) df = pd.DataFrame(data[0])
df.dropna(axis=0, how="any", inplace=True) df.dropna(axis=0, how="any", inplace=True)
self.dataset = df
X = df.drop(class_name, axis=1) X = df.drop(class_name, axis=1)
self.features = X.columns self.features = X.columns
self.class_name = class_name self.class_name = class_name
@@ -55,8 +56,12 @@ class DatasetsTanveer:
sep="\t", sep="\t",
index_col=0, index_col=0,
) )
X = data.drop("clase", axis=1).to_numpy() X = data.drop("clase", axis=1)
self.features = X.columns
X = X.to_numpy()
y = data["clase"].to_numpy() y = data["clase"].to_numpy()
self.dataset = data
self.class_name = "clase"
return X, y return X, y
@@ -77,8 +82,11 @@ class DatasetsSurcov:
) )
data.dropna(axis=0, how="any", inplace=True) data.dropna(axis=0, how="any", inplace=True)
self.columns = data.columns self.columns = data.columns
col_list = ["class"] X = data.drop(["class"], axis=1)
X = data.drop(col_list, axis=1).to_numpy() self.features = X.columns
self.class_name = "class"
self.dataset = data
X = X.to_numpy()
y = data["class"].to_numpy() y = data["class"].to_numpy()
return X, y return X, y
@@ -86,43 +94,42 @@ class DatasetsSurcov:
class Datasets: class Datasets:
def __init__(self, dataset_name=None): def __init__(self, dataset_name=None):
envData = EnvData.load() envData = EnvData.load()
class_name = getattr( # DatasetsSurcov, DatasetsTanveer, DatasetsArff,...
source_name = getattr(
__import__(__name__), __import__(__name__),
f"Datasets{envData['source_data']}", f"Datasets{envData['source_data']}",
) )
self.load = ( self.discretize = envData["discretize"] == "1"
self.load_discretized self.dataset = source_name()
if envData["discretize"] == "1"
else self.load_continuous
)
self.dataset = class_name()
self.class_names = [] self.class_names = []
self._load_names() self.data_sets = []
if dataset_name is not None: # initialize self.class_names & self.data_sets
try: class_names, sets = self._init_names(dataset_name)
class_name = self.class_names[ self.class_names = class_names
self.data_sets.index(dataset_name) self.data_sets = sets
]
self.class_names = [class_name]
except ValueError:
raise ValueError(f"Unknown dataset: {dataset_name}")
self.data_sets = [dataset_name]
def _load_names(self): def _init_names(self, dataset_name):
file_name = os.path.join(self.dataset.folder(), Files.index) file_name = os.path.join(self.dataset.folder(), Files.index)
default_class = "class" default_class = "class"
with open(file_name) as f: with open(file_name) as f:
self.data_sets = f.read().splitlines() sets = f.read().splitlines()
self.class_names = [default_class] * len(self.data_sets) class_names = [default_class] * len(sets)
if "," in self.data_sets[0]: if "," in sets[0]:
result = [] result = []
class_names = [] class_names = []
for data in self.data_sets: for data in sets:
name, class_name = data.split(",") name, class_name = data.split(",")
result.append(name) result.append(name)
class_names.append(class_name) class_names.append(class_name)
self.data_sets = result sets = result
self.class_names = class_names # Set as dataset list the dataset passed as argument
if dataset_name is None:
return class_names, sets
try:
class_name = class_names[sets.index(dataset_name)]
except ValueError:
raise ValueError(f"Unknown dataset: {dataset_name}")
return [class_name], [dataset_name]
def get_attributes(self, name): def get_attributes(self, name):
class Attributes: class Attributes:
@@ -148,14 +155,25 @@ class Datasets:
def get_class_name(self): def get_class_name(self):
return self.dataset.class_name return self.dataset.class_name
def load_continuous(self, name): def get_dataset(self):
return self.dataset.dataset
def load(self, name, dataframe=False):
try: try:
class_name = self.class_names[self.data_sets.index(name)] class_name = self.class_names[self.data_sets.index(name)]
return self.dataset.load(name, class_name) X, y = self.dataset.load(name, class_name)
if self.discretize:
X = self.discretize_dataset(X, y)
dataset = pd.DataFrame(X, columns=self.get_features())
dataset[self.get_class_name()] = y
self.dataset.dataset = dataset
if dataframe:
return self.get_dataset()
return X, y
except (ValueError, FileNotFoundError): except (ValueError, FileNotFoundError):
raise ValueError(f"Unknown dataset: {name}") raise ValueError(f"Unknown dataset: {name}")
def discretize(self, X, y): def discretize_dataset(self, X, y):
"""Supervised discretization with Fayyad and Irani's MDLP algorithm. """Supervised discretization with Fayyad and Irani's MDLP algorithm.
Parameters Parameters
@@ -173,14 +191,5 @@ class Datasets:
Xdisc = discretiz.fit_transform(X, y) Xdisc = discretiz.fit_transform(X, y)
return Xdisc return Xdisc
def load_discretized(self, name, dataframe=False):
X, yd = self.load_continuous(name)
Xd = self.discretize(X, yd)
dataset = pd.DataFrame(Xd, columns=self.get_features())
dataset[self.get_class_name()] = yd
if dataframe:
return dataset
return Xd, yd
def __iter__(self) -> Diterator: def __iter__(self) -> Diterator:
return Diterator(self.data_sets) return Diterator(self.data_sets)

View File

@@ -33,8 +33,8 @@ class DatasetTest(TestBase):
def test_load_dataframe(self): def test_load_dataframe(self):
self.set_env(".env.arff") self.set_env(".env.arff")
dt = Datasets() dt = Datasets()
X, y = dt.load_discretized("iris", dataframe=False) X, y = dt.load("iris", dataframe=False)
dataset = dt.load_discretized("iris", dataframe=True) dataset = dt.load("iris", dataframe=True)
class_name = dt.get_class_name() class_name = dt.get_class_name()
features = dt.get_features() features = dt.get_features()
self.assertListEqual(y.tolist(), dataset[class_name].tolist()) self.assertListEqual(y.tolist(), dataset[class_name].tolist())

View File

@@ -61,8 +61,6 @@ setuptools.setup(
classifiers=[ classifiers=[
"Development Status :: 4 - Beta", "Development Status :: 4 - Beta",
"License :: OSI Approved :: " + get_data("license"), "License :: OSI Approved :: " + get_data("license"),
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.10",
"Natural Language :: English", "Natural Language :: English",
"Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Artificial Intelligence",