From 1c5f1977e5db13fbfebc2db444aeea5a472dea13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Thu, 28 Oct 2021 11:55:40 +0200 Subject: [PATCH] Complete iwss based implementation (#2) --- mufs/Selection.py | 1 + mufs/tests/MUFS_test.py | 13 +++++++++---- mufs/tests/Metrics_test.py | 2 +- mufs/tests/__init__.py | 6 +++--- 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/mufs/Selection.py b/mufs/Selection.py index 7f96003..eb328e6 100755 --- a/mufs/Selection.py +++ b/mufs/Selection.py @@ -318,6 +318,7 @@ class MUFS: self._scores.append(merit_new) else: candidates.pop() + break if len(candidates) == self._max_features: break self._result = candidates diff --git a/mufs/tests/MUFS_test.py b/mufs/tests/MUFS_test.py index 312620d..1b60d61 100755 --- a/mufs/tests/MUFS_test.py +++ b/mufs/tests/MUFS_test.py @@ -1,11 +1,14 @@ import unittest +import os +import pandas as pd +import numpy as np from mdlp import MDLP from sklearn.datasets import load_wine, load_iris from ..Selection import MUFS -class MUFS_test(unittest.TestCase): +class MUFSTest(unittest.TestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) mdlp = MDLP(random_state=1) @@ -175,9 +178,6 @@ class MUFS_test(unittest.TestCase): mufs.iwss(self.X_w, self.y_w, -0.01) def test_iwss_better_merit_condition(self): - import pandas as pd - import os - folder = os.path.dirname(os.path.abspath(__file__)) data = pd.read_csv( os.path.join(folder, "balloons_R.dat"), @@ -189,3 +189,8 @@ class MUFS_test(unittest.TestCase): mufs = MUFS() expected = [0, 2, 3, 1] self.assertListEqual(expected, mufs.iwss(X, y, 0.3).get_results()) + + def test_iwss_empty(self): + mufs = MUFS() + X = np.delete(self.X_i, [0, 1], 1) + self.assertListEqual(mufs.iwss(X, self.y_i, 0.3).get_results(), [1, 0]) diff --git a/mufs/tests/Metrics_test.py b/mufs/tests/Metrics_test.py index 18ac46a..3a2a270 100755 --- a/mufs/tests/Metrics_test.py +++ b/mufs/tests/Metrics_test.py @@ -6,7 +6,7 @@ from mdlp import MDLP from ..Selection import Metrics -class Metrics_test(unittest.TestCase): +class MetricsTest(unittest.TestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) mdlp = MDLP(random_state=1) diff --git a/mufs/tests/__init__.py b/mufs/tests/__init__.py index e914937..466683c 100644 --- a/mufs/tests/__init__.py +++ b/mufs/tests/__init__.py @@ -1,4 +1,4 @@ -from .MUFS_test import MUFS_test -from .Metrics_test import Metrics_test +from .MUFS_test import MUFSTest +from .Metrics_test import MetricsTest -__all__ = ["MUFS_test", "Metrics_test"] +__all__ = ["MUFSTest", "MetricsTest"]