diff --git a/.vscode/launch.json b/.vscode/launch.json index 393a25e..d74b285 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -69,6 +69,17 @@ //"cwd": "/Users/rmontanana/Code/discretizbench", "cwd": "/home/rmontanana/Code/covbench", }, + { + "type": "lldb", + "request": "launch", + "name": "test", + "program": "${workspaceFolder}/build/tests/unit_tests", + "args": [ + "-c=\"Metrics Test\"", + // "-s", + ], + "cwd": "${workspaceFolder}/build/tests", + }, { "name": "Build & debug active file", "type": "cppdbg", diff --git a/src/BayesNet/Mst.cc b/src/BayesNet/Mst.cc index b915d76..11e4b2b 100644 --- a/src/BayesNet/Mst.cc +++ b/src/BayesNet/Mst.cc @@ -34,7 +34,7 @@ namespace bayesnet { void Graph::kruskal_algorithm() { // sort the edges ordered on decreasing weight - sort(G.begin(), G.end(), [](const auto& left, const auto& right) {return left.first > right.first;}); + stable_sort(G.begin(), G.end(), [](const auto& left, const auto& right) {return left.first > right.first;}); for (int i = 0; i < G.size(); i++) { int uSt, vEd; uSt = find_set(G[i].second.first); diff --git a/src/Platform/CMakeLists.txt b/src/Platform/CMakeLists.txt index 4e62a1f..05b8804 100644 --- a/src/Platform/CMakeLists.txt +++ b/src/Platform/CMakeLists.txt @@ -9,6 +9,7 @@ add_executable(b_main main.cc Folding.cc Experiment.cc Datasets.cc Dataset.cc Mo add_executable(b_manage manage.cc Results.cc Result.cc ReportConsole.cc ReportExcel.cc ReportBase.cc Datasets.cc Dataset.cc ExcelFile.cc) add_executable(b_list list.cc Datasets.cc Dataset.cc) add_executable(b_best best.cc BestResults.cc Result.cc Statistics.cc BestResultsExcel.cc ExcelFile.cc) +add_executable(testx testx.cpp Datasets.cc Dataset.cc Folding.cc) target_link_libraries(b_main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}") if (${CMAKE_HOST_SYSTEM_NAME} MATCHES "Linux") target_link_libraries(b_manage "${TORCH_LIBRARIES}" libxlsxwriter.so ArffFiles mdlp stdc++fs) @@ -17,4 +18,5 @@ else() target_link_libraries(b_manage 
"${TORCH_LIBRARIES}" "${XLSXWRITER_LIB}" ArffFiles mdlp) target_link_libraries(b_best Boost::boost "${XLSXWRITER_LIB}") endif() -target_link_libraries(b_list ArffFiles mdlp "${TORCH_LIBRARIES}") \ No newline at end of file +target_link_libraries(b_list ArffFiles mdlp "${TORCH_LIBRARIES}") +target_link_libraries(testx ArffFiles mdlp "${TORCH_LIBRARIES}") \ No newline at end of file diff --git a/src/Platform/Folding.cc b/src/Platform/Folding.cc index 48e03dd..d31f773 100644 --- a/src/Platform/Folding.cc +++ b/src/Platform/Folding.cc @@ -47,6 +47,7 @@ namespace platform { { stratified_indices = vector>(k); int fold_size = n / k; + cout << "Fold SIZE: " << fold_size << endl; // Compute class counts and indices auto class_indices = map>(); vector class_counts(*max_element(y.begin(), y.end()) + 1, 0); @@ -64,16 +65,20 @@ namespace platform { if (num_samples_to_take == 0) continue; auto remainder_samples_to_take = class_counts[label] % k; + cout << "Remainder samples to take: " << remainder_samples_to_take << endl; for (auto fold = 0; fold < k; ++fold) { auto it = next(class_indices[label].begin(), num_samples_to_take); move(class_indices[label].begin(), it, back_inserter(stratified_indices[fold])); // ## class_indices[label].erase(class_indices[label].begin(), it); } + auto chosen = vector(k, false); while (remainder_samples_to_take > 0) { int fold = (rand() % static_cast(k)); - if (stratified_indices[fold].size() == fold_size + 1) { + if (chosen.at(fold)) { continue; } + chosen[fold] = true; + cout << "One goes to fold " << fold << " that had " << stratified_indices[fold].size() << " elements before" << endl; auto it = next(class_indices[label].begin(), 1); stratified_indices[fold].push_back(*class_indices[label].begin()); class_indices[label].erase(class_indices[label].begin(), it); diff --git a/src/Platform/testx.cpp b/src/Platform/testx.cpp new file mode 100644 index 0000000..7bc392b --- /dev/null +++ b/src/Platform/testx.cpp @@ -0,0 +1,65 @@ +#include "Folding.h"
+#include "map" +#include "Datasets.h" +#include +#include +#include +using namespace std; +using namespace platform; + +string counts(vector y, vector indices) +{ + auto result = map(); + stringstream oss; + for (auto i = 0; i < indices.size(); ++i) { + result[y[indices[i]]]++; + } + string final_result = ""; + for (auto i = 0; i < result.size(); ++i) + oss << i << " -> " << setprecision(2) << fixed + << (double)result[i] * 100 / indices.size() << "% (" << result[i] << ") //"; + oss << endl; + return oss.str(); +} + +int main() +{ + map balance = { + {"iris", "33,33% (50) / 33,33% (50) / 33,33% (50)"}, + {"diabetes", "34,90% (268) / 65,10% (500)"}, + {"ecoli", "42,56% (143) / 22,92% (77) / 0,60% (2) / 0,60% (2) / 10,42% (35) / 5,95% (20) / 1,49% (5) / 15,48% (52)"}, + {"glass", "32,71% (70) / 7,94% (17) / 4,21% (9) / 35,51% (76) / 13,55% (29) / 6,07% (13)"} + }; + for (const auto& file_name : { "iris", "glass", "ecoli", "diabetes" }) { + auto dt = Datasets(true, "Arff"); + auto [X, y] = dt.getVectors(file_name); + //auto fold = KFold(5, 150); + auto fold = StratifiedKFold(5, y, -1); + cout << "***********************************************************************************************" << endl; + cout << "Dataset: " << file_name << endl; + cout << "NÂș Samples: " << dt.getNSamples(file_name) << endl; + cout << "Class states: " << dt.getNClasses(file_name) << endl; + cout << "Balance: " << balance.at(file_name) << endl; + for (int i = 0; i < 5; ++i) { + cout << "Fold: " << i << endl; + auto [train, test] = fold.getFold(i); + cout << "Train: "; + cout << "(" << train.size() << "): "; + // for (auto j = 0; j < static_cast(train.size()); j++) + // cout << train[j] << ", "; + cout << endl; + cout << "Train Statistics : " << counts(y, train); + cout << "-------------------------------------------------------------------------------" << endl; + cout << "Test: "; + cout << "(" << test.size() << "): "; + // for (auto j = 0; j < static_cast(test.size()); j++) + // cout 
<< test[j] << ", "; + cout << endl; + cout << "Test Statistics: " << counts(y, test); + cout << "==============================================================================" << endl; + } + cout << "***********************************************************************************************" << endl; + } + +} + diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 800006b..bb273e5 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -5,7 +5,7 @@ if(ENABLE_TESTING) include_directories(${BayesNet_SOURCE_DIR}/lib/Files) include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp) include_directories(${BayesNet_SOURCE_DIR}/lib/json/include) - set(TEST_SOURCES TestBayesModels.cc TestBayesNetwork.cc TestBayesMetrics.cc TestUtils.cc ${BayesNet_SOURCE_DIR}/src/Platform/Folding.cc ${BayesNet_SOURCES}) + set(TEST_SOURCES TestBayesModels.cc TestBayesNetwork.cc TestBayesMetrics.cc TestFolding.cc TestUtils.cc ${BayesNet_SOURCE_DIR}/src/Platform/Folding.cc ${BayesNet_SOURCES}) add_executable(${TEST_MAIN} ${TEST_SOURCES}) target_link_libraries(${TEST_MAIN} PUBLIC "${TORCH_LIBRARIES}" ArffFiles mdlp Catch2::Catch2WithMain) add_test(NAME ${TEST_MAIN} COMMAND ${TEST_MAIN}) diff --git a/tests/TestBayesMetrics.cc b/tests/TestBayesMetrics.cc index f6f420d..7f80c1b 100644 --- a/tests/TestBayesMetrics.cc +++ b/tests/TestBayesMetrics.cc @@ -22,19 +22,13 @@ TEST_CASE("Metrics Test", "[BayesNet]") {"diabetes", 0.0345470614} }; map>> resultsMST = { - {"glass", {{0,6}, {0,5}, {0,3}, {5,1}, {5,8}, {6,2}, {6,7}, {7,4}}}, + {"glass", {{0,6}, {0,5}, {0,3}, {6,2}, {6,7}, {5,1}, {5,8}, {5,4}}}, {"iris", {{0,1},{0,2},{1,3}}}, {"ecoli", {{0,1}, {0,2}, {1,5}, {1,3}, {5,6}, {5,4}}}, {"diabetes", {{0,7}, {0,2}, {0,6}, {2,3}, {3,4}, {3,5}, {4,1}}} }; - auto [XDisc, yDisc, featuresDisc, classNameDisc, statesDisc] = loadDataset(file_name, true, true); - int classNumStates = statesDisc.at(classNameDisc).size(); - auto yresized = torch::transpose(yDisc.view({ yDisc.size(0), 1 }), 0, 1); - 
torch::Tensor dataset = torch::cat({ XDisc, yresized }, 0); - int nSamples = dataset.size(1); - double epsilon = 1e-5; - torch::Tensor weights = torch::full({ nSamples }, 1.0 / nSamples, torch::kDouble); - bayesnet::Metrics metrics(dataset, featuresDisc, classNameDisc, classNumStates); + auto raw = RawDatasets(file_name, true); + bayesnet::Metrics metrics(raw.dataset, raw.featurest, raw.classNamet, raw.classNumStates); SECTION("Test Constructor") { @@ -43,21 +37,21 @@ TEST_CASE("Metrics Test", "[BayesNet]") SECTION("Test SelectKBestWeighted") { - vector kBest = metrics.SelectKBestWeighted(weights, true, resultsKBest.at(file_name).first); + vector kBest = metrics.SelectKBestWeighted(raw.weights, true, resultsKBest.at(file_name).first); REQUIRE(kBest.size() == resultsKBest.at(file_name).first); REQUIRE(kBest == resultsKBest.at(file_name).second); } SECTION("Test Mutual Information") { - auto result = metrics.mutualInformation(dataset.index({ 1, "..." }), dataset.index({ 2, "..." }), weights); - REQUIRE(result == Catch::Approx(resultsMI.at(file_name)).epsilon(epsilon)); + auto result = metrics.mutualInformation(raw.dataset.index({ 1, "..." }), raw.dataset.index({ 2, "..." 
}), raw.weights); + REQUIRE(result == Catch::Approx(resultsMI.at(file_name)).epsilon(raw.epsilon)); } SECTION("Test Maximum Spanning Tree") { - auto weights_matrix = metrics.conditionalEdge(weights); - auto result = metrics.maximumSpanningTree(featuresDisc, weights_matrix, 0); + auto weights_matrix = metrics.conditionalEdge(raw.weights); + auto result = metrics.maximumSpanningTree(raw.featurest, weights_matrix, 0); REQUIRE(result == resultsMST.at(file_name)); } } \ No newline at end of file diff --git a/tests/TestBayesModels.cc b/tests/TestBayesModels.cc index c0d9b25..51a5d27 100644 --- a/tests/TestBayesModels.cc +++ b/tests/TestBayesModels.cc @@ -34,83 +34,81 @@ TEST_CASE("Test Bayesian Classifiers score", "[BayesNet]") }; string file_name = GENERATE("glass", "iris", "ecoli", "diabetes"); - auto [XCont, yCont, featuresCont, classNameCont, statesCont] = loadDataset(file_name, true, false); - auto [XDisc, yDisc, featuresDisc, classNameDisc, statesDisc] = loadFile(file_name); - double epsilon = 1e-5; + auto raw = RawDatasets(file_name, false); SECTION("Test TAN classifier (" + file_name + ")") { auto clf = bayesnet::TAN(); - clf.fit(XDisc, yDisc, featuresDisc, classNameDisc, statesDisc); - auto score = clf.score(XDisc, yDisc); + clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv); + auto score = clf.score(raw.Xv, raw.yv); //scores[{file_name, "TAN"}] = score; - REQUIRE(score == Catch::Approx(scores[{file_name, "TAN"}]).epsilon(epsilon)); + REQUIRE(score == Catch::Approx(scores[{file_name, "TAN"}]).epsilon(raw.epsilon)); } SECTION("Test TANLd classifier (" + file_name + ")") { auto clf = bayesnet::TANLd(); - clf.fit(XCont, yCont, featuresCont, classNameCont, statesCont); - auto score = clf.score(XCont, yCont); + clf.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest); + auto score = clf.score(raw.Xt, raw.yt); //scores[{file_name, "TANLd"}] = score; - REQUIRE(score == Catch::Approx(scores[{file_name, "TANLd"}]).epsilon(epsilon)); + 
REQUIRE(score == Catch::Approx(scores[{file_name, "TANLd"}]).epsilon(raw.epsilon)); } SECTION("Test KDB classifier (" + file_name + ")") { auto clf = bayesnet::KDB(2); - clf.fit(XDisc, yDisc, featuresDisc, classNameDisc, statesDisc); - auto score = clf.score(XDisc, yDisc); + clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv); + auto score = clf.score(raw.Xv, raw.yv); //scores[{file_name, "KDB"}] = score; REQUIRE(score == Catch::Approx(scores[{file_name, "KDB" - }]).epsilon(epsilon)); + }]).epsilon(raw.epsilon)); } SECTION("Test KDBLd classifier (" + file_name + ")") { auto clf = bayesnet::KDBLd(2); - clf.fit(XCont, yCont, featuresCont, classNameCont, statesCont); - auto score = clf.score(XCont, yCont); + clf.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest); + auto score = clf.score(raw.Xt, raw.yt); //scores[{file_name, "KDBLd"}] = score; REQUIRE(score == Catch::Approx(scores[{file_name, "KDBLd" - }]).epsilon(epsilon)); + }]).epsilon(raw.epsilon)); } SECTION("Test SPODE classifier (" + file_name + ")") { auto clf = bayesnet::SPODE(1); - clf.fit(XDisc, yDisc, featuresDisc, classNameDisc, statesDisc); - auto score = clf.score(XDisc, yDisc); + clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv); + auto score = clf.score(raw.Xv, raw.yv); // scores[{file_name, "SPODE"}] = score; - REQUIRE(score == Catch::Approx(scores[{file_name, "SPODE"}]).epsilon(epsilon)); + REQUIRE(score == Catch::Approx(scores[{file_name, "SPODE"}]).epsilon(raw.epsilon)); } SECTION("Test SPODELd classifier (" + file_name + ")") { auto clf = bayesnet::SPODELd(1); - clf.fit(XCont, yCont, featuresCont, classNameCont, statesCont); - auto score = clf.score(XCont, yCont); + clf.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest); + auto score = clf.score(raw.Xt, raw.yt); // scores[{file_name, "SPODELd"}] = score; - REQUIRE(score == Catch::Approx(scores[{file_name, "SPODELd"}]).epsilon(epsilon)); + REQUIRE(score == Catch::Approx(scores[{file_name, 
"SPODELd"}]).epsilon(raw.epsilon)); } SECTION("Test AODE classifier (" + file_name + ")") { auto clf = bayesnet::AODE(); - clf.fit(XDisc, yDisc, featuresDisc, classNameDisc, statesDisc); - auto score = clf.score(XDisc, yDisc); + clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv); + auto score = clf.score(raw.Xv, raw.yv); // scores[{file_name, "AODE"}] = score; - REQUIRE(score == Catch::Approx(scores[{file_name, "AODE"}]).epsilon(epsilon)); + REQUIRE(score == Catch::Approx(scores[{file_name, "AODE"}]).epsilon(raw.epsilon)); } SECTION("Test AODELd classifier (" + file_name + ")") { auto clf = bayesnet::AODELd(); - clf.fit(XCont, yCont, featuresCont, classNameCont, statesCont); - auto score = clf.score(XCont, yCont); + clf.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest); + auto score = clf.score(raw.Xt, raw.yt); // scores[{file_name, "AODELd"}] = score; - REQUIRE(score == Catch::Approx(scores[{file_name, "AODELd"}]).epsilon(epsilon)); + REQUIRE(score == Catch::Approx(scores[{file_name, "AODELd"}]).epsilon(raw.epsilon)); } SECTION("Test BoostAODE classifier (" + file_name + ")") { auto clf = bayesnet::BoostAODE(); - clf.fit(XDisc, yDisc, featuresDisc, classNameDisc, statesDisc); - auto score = clf.score(XDisc, yDisc); + clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv); + auto score = clf.score(raw.Xv, raw.yv); // scores[{file_name, "BoostAODE"}] = score; - REQUIRE(score == Catch::Approx(scores[{file_name, "BoostAODE"}]).epsilon(epsilon)); + REQUIRE(score == Catch::Approx(scores[{file_name, "BoostAODE"}]).epsilon(raw.epsilon)); } // for (auto scores : scores) { // cout << "{{\"" << scores.first.first << "\", \"" << scores.first.second << "\"}, " << scores.second << "}, "; @@ -125,10 +123,9 @@ TEST_CASE("Models features", "[BayesNet]") "sepallength -> sepalwidth", "sepalwidth [shape=circle] \n", "sepalwidth -> petalwidth", "}\n" } ); - + auto raw = RawDatasets("iris", true); auto clf = bayesnet::TAN(); - auto [XDisc, 
yDisc, featuresDisc, classNameDisc, statesDisc] = loadFile("iris"); - clf.fit(XDisc, yDisc, featuresDisc, classNameDisc, statesDisc); + clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv); REQUIRE(clf.getNumberOfNodes() == 6); REQUIRE(clf.getNumberOfEdges() == 7); REQUIRE(clf.show() == vector{"class -> sepallength, sepalwidth, petallength, petalwidth, ", "petallength -> sepallength, ", "petalwidth -> ", "sepallength -> sepalwidth, ", "sepalwidth -> petalwidth, "}); @@ -136,9 +133,9 @@ TEST_CASE("Models features", "[BayesNet]") } TEST_CASE("Get num features & num edges", "[BayesNet]") { - auto [XDisc, yDisc, featuresDisc, classNameDisc, statesDisc] = loadFile("iris"); + auto raw = RawDatasets("iris", true); auto clf = bayesnet::KDB(2); - clf.fit(XDisc, yDisc, featuresDisc, classNameDisc, statesDisc); + clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv); REQUIRE(clf.getNumberOfNodes() == 6); REQUIRE(clf.getNumberOfEdges() == 8); } \ No newline at end of file diff --git a/tests/TestFolding.cc b/tests/TestFolding.cc new file mode 100644 index 0000000..783258d --- /dev/null +++ b/tests/TestFolding.cc @@ -0,0 +1,98 @@ +#include +#include +#include +#include "TestUtils.h" +#include "Folding.h" + +TEST_CASE("KFold Test", "[KFold]") +{ + // Initialize a KFold object with k=5 and a seed of 19. + string file_name = GENERATE("glass", "iris", "ecoli", "diabetes"); + auto raw = RawDatasets(file_name, true); + int nFolds = 5; + platform::KFold kfold(nFolds, raw.nSamples, 19); + int number = raw.nSamples * (kfold.getNumberOfFolds() - 1) / kfold.getNumberOfFolds(); + + SECTION("Number of Folds") + { + REQUIRE(kfold.getNumberOfFolds() == nFolds); + } + SECTION("Fold Test") + { + // Test each fold's size and contents.
+ for (int i = 0; i < nFolds; ++i) { + auto [train_indices, test_indices] = kfold.getFold(i); + bool result = train_indices.size() == number || train_indices.size() == number + 1; + REQUIRE(result); + REQUIRE(train_indices.size() + test_indices.size() == raw.nSamples); + } + } +} + +map counts(vector y, vector indices) +{ + map result; + for (auto i = 0; i < indices.size(); ++i) { + result[y[indices[i]]]++; + } + return result; +} + +TEST_CASE("StratifiedKFold Test", "[StratifiedKFold]") +{ + int nFolds = 3; + // Initialize a StratifiedKFold object with k=3, using the y vector, and a seed of 17. + string file_name = GENERATE("glass", "iris", "ecoli", "diabetes"); + auto raw = RawDatasets(file_name, true); + platform::StratifiedKFold stratified_kfoldt(nFolds, raw.yt, 17); + platform::StratifiedKFold stratified_kfoldv(nFolds, raw.yv, 17); + int number = raw.nSamples * (stratified_kfoldt.getNumberOfFolds() - 1) / stratified_kfoldt.getNumberOfFolds(); + + // SECTION("Number of Folds") + // { + // REQUIRE(stratified_kfold.getNumberOfFolds() == nFolds); + // } + SECTION("Fold Test") + { + // Test each fold's size and contents.
+ auto counts = vector(raw.classNumStates, 0); + for (int i = 0; i < nFolds; ++i) { + auto [train_indicest, test_indicest] = stratified_kfoldt.getFold(i); + auto [train_indicesv, test_indicesv] = stratified_kfoldv.getFold(i); + REQUIRE(train_indicest == train_indicesv); + REQUIRE(test_indicest == test_indicesv); + + bool result = train_indicest.size() == number || train_indicest.size() == number + 1; + REQUIRE(result); + REQUIRE(train_indicest.size() + test_indicest.size() == raw.nSamples); + auto train_t = torch::tensor(train_indicest); + auto ytrain = raw.yt.index({ train_t }); + cout << "dataset=" << file_name << endl; + cout << "nSamples=" << raw.nSamples << endl; + cout << "number=" << number << endl; + cout << "train_indices.size()=" << train_indicest.size() << endl; + cout << "test_indices.size()=" << test_indicest.size() << endl; + cout << "Class Name = " << raw.classNamet << endl; + cout << "Features = "; + for (const auto& item : raw.featurest) { + cout << item << ", "; + } + cout << endl; + cout << "Class States: "; + for (const auto& item : raw.statest.at(raw.classNamet)) { + cout << item << ", "; + } + cout << endl; + // Check that the class labels have been equally assigned to each fold + for (const auto& idx : train_indicest) { + counts[raw.yt[idx].item()]++; + } + int j = 0; + for (const auto& item : counts) { + cout << "j=" << j++ << item << endl; + } + + } + REQUIRE(1 == 1); + } +} diff --git a/tests/TestUtils.h b/tests/TestUtils.h index 4849049..f442f85 100644 --- a/tests/TestUtils.h +++ b/tests/TestUtils.h @@ -14,6 +14,29 @@ pair, map> discretize(vector discretizeDataset(vector& X, mdlp::labels_t& y); tuple>, vector, vector, string, map>> loadFile(const string& name); tuple, string, map>> loadDataset(const string& name, bool class_last, bool discretize_dataset); -#endif //TEST_UTILS_H -# \ No newline at end of file +class RawDatasets { +public: + RawDatasets(const string& file_name, bool discretize) + { + // Xt can be either discretized or not + tie(Xt,
yt, featurest, classNamet, statest) = loadDataset(file_name, true, discretize); + // Xv is always discretized + tie(Xv, yv, featuresv, classNamev, statesv) = loadFile(file_name); + auto yresized = torch::transpose(yt.view({ yt.size(0), 1 }), 0, 1); + dataset = torch::cat({ Xt, yresized }, 0); + nSamples = dataset.size(1); + weights = torch::full({ nSamples }, 1.0 / nSamples, torch::kDouble); + classNumStates = discretize ? statest.at(classNamet).size() : 0; + } + torch::Tensor Xt, yt, dataset, weights; + vector> Xv; + vector yv; + vector featurest, featuresv; + map> statest, statesv; + string classNamet, classNamev; + int nSamples, classNumStates; + double epsilon = 1e-5; +}; + +#endif //TEST_UTILS_H \ No newline at end of file