#%%
# Shared imports and hyper-parameters for the OpenBCI movement classifier notebook.
import os
import socket
import json
from datetime import time  # NOTE(review): not referenced anywhere in the visible notebook

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_validate, StratifiedKFold, train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report

# Seed shared by the random forest and the stratified CV splitter for reproducibility.
random_state = 19

# Number of stratified cross-validation folds.
splits = 5

# Number of trees in the random forest.
estimators = 100

#%% md
# # Working socket version

#%% md
# # Data Preparation
#%%
# Read the recorded OpenBCI session from csv/openbci.csv.
dataset = pd.read_csv(os.path.join("csv", 'openbci.csv'))

# Keep the 16 EXG channels followed by the "Movement" label column (label last).
columnas_utiles = [f"EXG Channel {i}" for i in range(16)] + ["Movement"]
dataset = dataset[columnas_utiles]

# Features are every column except the trailing label; record their names for reference.
columnas_modelo = dataset.columns[:-1].tolist()
X = dataset[columnas_modelo].values
print("✔️ Columnas usadas en el modelo:", columnas_modelo)

# Target vector: the movement label of each row.
y = dataset["Movement"].values

# Report the dataset dimensions and the class balance as percentages.
for nombre, matriz in (("X", X), ("y", y)):
    print(f"{nombre} shape: {matriz.shape}")
print("Labels distribution")
print("===================")
print(dataset["Movement"].value_counts(normalize=True) * 100)

#%% md
# # Stratified Cross Validation
#%%
import joblib

# Estimate accuracy with shuffled, stratified k-fold cross-validation.
cv = StratifiedKFold(n_splits=splits, shuffle=True, random_state=random_state)
clf = RandomForestClassifier(n_estimators=estimators, n_jobs=-1, random_state=random_state)

scores = cross_validate(
    clf,
    X,
    y,
    scoring="accuracy",
    cv=cv,
    n_jobs=-1,
    return_train_score=True,
)
test_scores = scores["test_score"]
print(f"Accuracy: {np.mean(test_scores)} (+/- {np.std(test_scores)})")

# Refit on every sample and persist; this file is what the live-classification cell loads.
clf.fit(X, y)
joblib.dump(clf, 'modelo_entrenado.pkl')

#%% md
# # Live Classification
#%%
import socket
import json
import numpy as np
import joblib
import csv
import os
from datetime import datetime

# Model saved by the cross-validation cell.
clf = joblib.load('modelo_entrenado.pkl')

# Global settings: local address and port on which UDP packets are received.
IP = "127.0.0.1"
PORT = 12345

# Number of EXG channels expected per packet (the model was trained on 16 channels).
N_CANALES = 16

# Maximum datagram size read per recvfrom call.
BUFFER_SIZE = 65535

# Unique output file name for this session.
# TODO(review): output_file is never written to and `csv` is unused; logging the
# samples together with their predictions looks unfinished.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_file = f"samples_con_predicciones_{timestamp}.csv"


def _parse_samples(data, n_channels=N_CANALES):
    """Decode one UDP datagram into a (samples, channels) array.

    Expects a UTF-8 JSON object whose ``"data"`` key holds one list of readings
    per channel, e.g. ``{"data": [[ch0_s0, ch0_s1, ...], ..., [ch15_s0, ...]]}``.
    A flat list of ``n_channels`` numbers is treated as a single sample.
    NOTE(review): schema inferred from this notebook's own parsing; confirm
    against the sender's configuration.

    Returns:
        A 2-D float array with one row per sample and one column per channel,
        or None (after printing why) when the channel count is wrong or the
        packet holds no samples.

    Raises:
        UnicodeDecodeError / ValueError (including json.JSONDecodeError and
        ragged or non-numeric data), KeyError, TypeError on malformed payloads.
    """
    mensaje = json.loads(data.decode("utf-8"))
    datos_np = np.asarray(mensaje["data"], dtype=float)

    # One reading per channel -> a single sample.
    if datos_np.ndim == 1:
        datos_np = datos_np.reshape(-1, 1)

    if datos_np.ndim != 2 or datos_np.shape[0] != n_channels:
        recibidos = datos_np.shape[0] if datos_np.ndim else 0
        print(f"Se esperaban {n_channels} canales, se recibieron {recibidos}")
        return None

    if datos_np.shape[1] == 0:
        print("Paquete sin muestras, se ignora")
        return None

    # Rows are channels; transpose so each row is one multi-channel sample,
    # matching the column layout of the training data.
    return datos_np.T


def _majority_vote(predicciones):
    """Return ``(label, votes)`` for the most frequent prediction.

    Ties resolve to the first label in np.unique's sorted order.
    """
    etiquetas, conteos = np.unique(predicciones, return_counts=True)
    mejor = np.argmax(conteos)
    return etiquetas[mejor], conteos[mejor]


# The `with` block closes the socket even when the loop is interrupted, so the
# port is released and this cell can be re-run.
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
    sock.bind((IP, PORT))
    print(f"Esperando datos en {IP}:{PORT}...")

    while True:
        # Block until the next packet arrives.
        data, _ = sock.recvfrom(BUFFER_SIZE)
        try:
            samples = _parse_samples(data)
            if samples is None:
                continue

            # Classify every sample in the packet and report the majority label;
            # the "probability" is the share of samples that voted for it.
            pred = clf.predict(samples)
            res, rep = _majority_vote(pred)
            print(f"Predicción: {res} with a probability of {rep/len(samples)*100}%")
        except Exception as e:
            # Best-effort live loop: report the bad packet and keep listening.
            print("Error procesando mensaje:", e)

#%%
# Disabled scratch cell: everything below is commented out. It contains:
#   1. a check of the channel -> sample regrouping (as done in the live loop) on a
#      hard-coded packet of 16 channels x 5 readings, plus an np.unique vote-count demo;
#   2. a hold-out classification report and confusion matrix
#      (NOTE(review): running it refits `clf` on the 80% training split);
#   3. an earlier raw UDP receiver that only prints incoming OpenBCI packets.
# Uncomment a section to run it.
# import numpy as np
#
# a= [[11284.547,12447.82,12083.107,11501.806,12807.84],[-12317.487,-11649.863,-12289.168,-11944.437,-11870.162],[-44142.168,-43465.47,-44276.66,-43668.535,-43838.836],[-48939.258,-49422.3,-49312.082,-49024.46,-49621.812],[-46515.21,-42789.35,-45057.406,-45182.152,-42718.074],[-50600.24,-49273.01,-48712.496,-50988.35,-47901.44],[-55553.695,-54737.543,-55705.754,-54981.83,-55189.562],[-47437.688,-47808.816,-47803.074,-47444.28,-48032.02],[-32000.01,-31640.191,-32217.111,-31654.541,-31996.723],[-34022.508,-33585.953,-34134.49,-33706.766,-33866.96],[-52641.195,-52273.914,-52842.473,-52302.656,-52619.94],[-37791.434,-37473.55,-38074.094,-37444.223,-37881.824],[-50217.285,-49953.332,-50541.453,-49866.676,-50376.137],[-54649.746,-54273.344,-54846.934,-54319.88,-54617.36],[-36739.273,-36184.703,-36606.793,-36466.848,-36266.13],[-63696.46,-67149.516,-62522.164,-66528.56,-64694.105]]
# samples = []
# for i in range(5):
#     samples.append([a[j][i] for j in range(16)])
# print(np.array(samples))
# a = np.array(["i", "i", "d", "d", "d", "none"]).reshape(-1, 1)
# print(a)
# c, k = np.unique(a, return_counts=True)
# print(c, k)
# #%% md
# # # Classification Report
# #%%
# X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=random_state, stratify=y)
# clf.fit(X_train, y_train)
#
# y_pred = clf.predict(X_test)
# print(classification_report(y_test, y_pred, digits=7))
# #%% md
# # # Confusion Matrix
# #%%
# cm = confusion_matrix(y_test, y_pred)
# disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=clf.classes_)
# disp.plot()
# #%%
# import socket
# import datetime
#
# # 📡 UDP server configuration for receiving OpenBCI data
# UDP_IP = "127.0.0.1"  # Address OpenBCI is sending the data to
# UDP_PORT = 12345  # Port configured in the OpenBCI GUI
# BUFFER_SIZE = 4096  # Make sure there is enough room to receive the data
#
# # 🟢 Create UDP socket
# sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
# sock.bind((UDP_IP, UDP_PORT))
#
# print(f"🔄 Servidor UDP en {UDP_IP}:{UDP_PORT}, esperando datos de OpenBCI...")
#
# while True:
#     try:
#         # 🟡 Receive UDP data from OpenBCI
#         data, addr = sock.recvfrom(BUFFER_SIZE)
#         decoded_data = data.decode("utf-8").strip()
#
#         # 🟠 Print the received data to the console
#         print(f"{decoded_data} Time: {datetime.datetime.now()}")
#
#     except socket.timeout:
#         print("⏳ No hay datos recibidos aún... (Verifica la configuración de OpenBCI)")
#
#     except Exception as e:
#         print(f"❌ Error al recibir datos: {e}")