From 8cf823e843785a2ed218c644915feb95acf8e011 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana?= Date: Tue, 1 Nov 2022 12:24:50 +0100 Subject: [PATCH] Add custom seeds to .env --- .env.dist | 1 + benchmark/Experiments.py | 9 ++++++--- benchmark/Results.py | 1 - benchmark/tests/.env | 1 + benchmark/tests/.env.arff | 1 + benchmark/tests/.env.dist | 1 + benchmark/tests/.env.surcov | 1 + benchmark/tests/Dataset_test.py | 7 ++++++- benchmark/tests/Util_test.py | 1 + 9 files changed, 18 insertions(+), 5 deletions(-) diff --git a/.env.dist b/.env.dist index 93ede38..a540dfe 100644 --- a/.env.dist +++ b/.env.dist @@ -4,3 +4,4 @@ n_folds=5 model=ODTE stratified=0 source_data=Tanveer +seeds=[57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] diff --git a/benchmark/Experiments.py b/benchmark/Experiments.py index 6e4b763..6092955 100644 --- a/benchmark/Experiments.py +++ b/benchmark/Experiments.py @@ -16,10 +16,13 @@ from sklearn.model_selection import ( from .Utils import Folders, Files, NO_RESULTS from .Datasets import Datasets from .Models import Models +from .Arguments import EnvData class Randomized: - seeds = [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] + @staticmethod + def seeds(): + return json.loads(EnvData.load()["seeds"]) class BestResults: @@ -155,7 +158,7 @@ class Experiment: self.platform = platform self.progress_bar = progress_bar self.folds = folds - self.random_seeds = Randomized.seeds + self.random_seeds = Randomized.seeds() self.results = [] self.duration = 0 self._init_experiment() @@ -308,7 +311,7 @@ class GridSearch: self.progress_bar = progress_bar self.folds = folds self.platform = platform - self.random_seeds = Randomized.seeds + self.random_seeds = Randomized.seeds() self.grid_file = os.path.join( Folders.results, Files.grid_input(score_name, model_name) ) diff --git a/benchmark/Results.py b/benchmark/Results.py index b76434a..bb34163 100644 --- a/benchmark/Results.py +++ b/benchmark/Results.py @@ -16,7 +16,6 @@ from .Utils import ( Symbols, TextColor, NO_RESULTS, - PYTHON_VERSION, ) diff --git a/benchmark/tests/.env b/benchmark/tests/.env index 819f93f..31a99ab 100644 --- a/benchmark/tests/.env +++ b/benchmark/tests/.env @@ -5,3 +5,4 @@ model=ODTE stratified=0 # Source of data Tanveer/Surcov source_data=Tanveer +seeds=[57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] diff --git a/benchmark/tests/.env.arff b/benchmark/tests/.env.arff index 3cff1df..ab8956d 100644 --- a/benchmark/tests/.env.arff +++ b/benchmark/tests/.env.arff @@ -4,3 +4,4 @@ n_folds=5 model=ODTE stratified=0 source_data=Arff +seeds=[271, 314, 171] \ No newline at end of file diff --git a/benchmark/tests/.env.dist b/benchmark/tests/.env.dist index 819f93f..31a99ab 100644 --- a/benchmark/tests/.env.dist +++ b/benchmark/tests/.env.dist @@ -5,3 +5,4 @@ model=ODTE stratified=0 # Source of data Tanveer/Surcov source_data=Tanveer +seeds=[57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] diff --git a/benchmark/tests/.env.surcov b/benchmark/tests/.env.surcov index 01deb63..63cc579 100644 --- a/benchmark/tests/.env.surcov +++ b/benchmark/tests/.env.surcov @@ -5,3 +5,4 @@ model=ODTE stratified=0 # Source of data Tanveer/Surcov source_data=Surcov +seeds=[57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] \ No newline at end of file diff --git a/benchmark/tests/Dataset_test.py b/benchmark/tests/Dataset_test.py index ca28453..f00a982 100644 --- a/benchmark/tests/Dataset_test.py +++ b/benchmark/tests/Dataset_test.py @@ -23,7 +23,12 @@ class DatasetTest(TestBase): def test_Randomized(self): expected = [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] - self.assertSequenceEqual(Randomized.seeds, expected) + self.assertSequenceEqual(Randomized.seeds(), expected) + + def test_Randomized_3_seeds(self): + self.set_env(".env.arff") + expected = [271, 314, 171] + self.assertSequenceEqual(Randomized.seeds(), expected) def test_Datasets_iterator(self): test = { diff --git a/benchmark/tests/Util_test.py b/benchmark/tests/Util_test.py index 6084f52..1020a5e 100644 --- a/benchmark/tests/Util_test.py +++ b/benchmark/tests/Util_test.py @@ -178,6 +178,7 @@ class UtilTest(TestBase): "model": "ODTE", "stratified": "0", "source_data": "Tanveer", + "seeds": "[57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]", } computed = EnvData().load() self.assertDictEqual(computed, expected)