diff --git a/src/simulation/scripts/lanch_one_simu.py b/src/simulation/scripts/lanch_one_simu.py index 8779f0b..2c309e8 100644 --- a/src/simulation/scripts/lanch_one_simu.py +++ b/src/simulation/scripts/lanch_one_simu.py @@ -1,35 +1,31 @@ -raise NotImplementedError("This file is currently begin worked on") - import os import sys +from typing import * +import numpy as np import onnxruntime as ort +import gymnasium as gym -from simulation import ( - VehicleEnv, -) -from simulation import config as c -from utils import onnx_utils +from simulation.config import * +from utils import run_onnx_model -# ------------------------------------------------------------------------- +from extractors import ( # noqa: F401 + CNN1DResNetExtractor, + TemporalResNetExtractor, +) +from simulation import VehicleEnv -# --- Chemin vers le fichier ONNX --- +# ------------------------------------------------------------------------- -ONNX_MODEL_PATH = "model.onnx" +ONNX_MODEL_PATH = "/home/exo/Bureau/CoVAPSy/model.onnx" -# --- Initialisation du moteur d'inférence ONNX Runtime (ORT) --- +# --- Launching of inference motor ONNX Runtime (ORT) --- def init_onnx_runtime_session(onnx_path: str) -> ort.InferenceSession: if not os.path.exists(onnx_path): - raise FileNotFoundError( - f"Le fichier ONNX est introuvable à : {onnx_path}. Veuillez l'exporter d'abord." - ) - - # Crée la session d'inférence - return ort.InferenceSession( - onnx_path - ) # On peut modifier le providers afin de mettre une CUDA + raise FileNotFoundError(f"Le fichier ONNX est introuvable à : {onnx_path}. Veuillez l'exporter d'abord.") + return ort.InferenceSession(onnx_path) if __name__ == "__main__": @@ -38,7 +34,7 @@ def init_onnx_runtime_session(onnx_path: str) -> ort.InferenceSession: os.system('if [ -n "$(ls /tmp/autotech)" ]; then rm /tmp/autotech/*; fi') - # 2. Initialisation de la session ONNX Runtime + # Starting of OnnxSession try: ort_session = init_onnx_runtime_session(ONNX_MODEL_PATH) input_name = ort_session.get_inputs()[0].name @@ -47,33 +43,48 @@ def init_onnx_runtime_session(onnx_path: str) -> ort.InferenceSession: print(f"Input Name: {input_name}, Output Name: {output_name}") except FileNotFoundError as e: print(f"ERREUR : {e}") - print( - "Veuillez vous assurer que vous avez exécuté une fois le script d'entraînement pour exporter 'model.onnx'." - ) sys.exit(1) - # 3. Boucle d'inférence (Test) env = VehicleEnv(0, 0) - obs = env.reset() + obs, _ = env.reset() + print("Début de la simulation en mode inférence...") - max_steps = 5000 step_count = 0 while True: - action = onnx_utils.run_onnx_model(ort_session, obs) + # 1. On récupère les logits (probabilités) bruts de l'ONNX + raw_action = run_onnx_model(ort_session, obs[None]) + logits = np.array(raw_action).flatten() + + # 2. On sépare le tableau en deux (Direction et Vitesse) + # On utilise n_actions_steering et n_actions_speed venant de config.py + steer_logits = logits[:n_actions_steering] + speed_logits = logits[n_actions_steering:] + + # 3. L'IA choisit l'action qui a le score (logit) le plus élevé + action_steer = np.argmax(steer_logits) + action_speed = np.argmax(speed_logits) - # 4. Exécuter l'action dans l'environnement - obs, reward, done, info = env.step(action) + # 4. On crée le tableau final parfaitement formaté pour Webots (strictement 2 entiers) + action = np.array([action_steer, action_speed], dtype=np.int64) - # Note: L'environnement Webots gère généralement son propre affichage - # env.render() # Décommenter si votre env supporte le rendu externe + # 5. Exécuter l'action dans l'environnement + next_obs, reward, done, truncated, info = env.step(action) + + step_count += 1 # Gestion des fins d'épisodes if done: print(f"Épisode(s) terminé(s) après {step_count} étapes.") - obs = env.reset() + step_count = 0 + + fresh_frame = next_obs[:, -1:] + obs, _ = env.reset() + env.context[:, -1:] = fresh_frame + obs = env.context + else: + obs = next_obs - # Fermeture propre (très important pour les processus parallèles SubprocVecEnv) - envs.close() - print("Simulation terminée. Environnements fermés.") + env.close() + print("Simulation terminée. Environnements fermés.") \ No newline at end of file diff --git a/src/simulation/src/simulation/config.py b/src/simulation/src/simulation/config.py index 1e8864d..7e572eb 100644 --- a/src/simulation/src/simulation/config.py +++ b/src/simulation/src/simulation/config.py @@ -11,7 +11,7 @@ n_map = 2 n_simulations = 1 -n_vehicles = 2 +n_vehicles = 1 n_stupid_vehicles = 0 n_actions_steering = 16 n_actions_speed = 16 @@ -26,3 +26,5 @@ LOG_LEVEL = logging.INFO FORMATTER = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + +B_DEBUG = False \ No newline at end of file diff --git a/src/simulation/src/utils/__init__.py b/src/simulation/src/utils/__init__.py index 4b7ee72..7e56893 100644 --- a/src/simulation/src/utils/__init__.py +++ b/src/simulation/src/utils/__init__.py @@ -1,3 +1,8 @@ from .plot_model_io import PlotModelIO +import onnxruntime as ort +import numpy as np __all__ = ["PlotModelIO"] + +def run_onnx_model(session: ort.InferenceSession, x: np.ndarray): + return session.run(None, {"input": x})[0] \ No newline at end of file diff --git a/uv.lock b/uv.lock index 7effdcb..7ac586b 100644 --- a/uv.lock +++ b/uv.lock @@ -852,6 +852,11 @@ dependencies = [ { name = "zmq" }, ] +[package.optional-dependencies] +controller = [ + { name = "pygame" }, +] + [package.metadata] requires-dist = [ { name = "adafruit-blinka", specifier = ">=8.0.0" }, @@ -869,6 +874,7 @@ requires-dist = [ { name = "onnxruntime", specifier = ">=1.8.0" }, { name = "opencv-python", specifier = ">=4.12.0.88" }, { name = "picamera2", specifier = ">=0.3.0" }, + { name = "pygame", marker = "extra == 'controller'", specifier = ">=2.6.1" }, { name = "pyps4controller", specifier = ">=1.2.0" }, { name = "rpi-gpio", specifier = ">=0.7.1" }, { name = "rpi-hardware-pwm", specifier = ">=0.1.0" }, @@ -879,6 +885,7 @@ requires-dist = [ { name = "websockets", specifier = ">=16.0" }, { name = "zmq", specifier = ">=0.0.0" }, ] +provides-extras = ["controller"] [[package]] name = "humanfriendly" @@ -2392,6 +2399,9 @@ dependencies = [ wheels = [ { url = "https://files.pythonhosted.org/packages/d3/54/a2ba279afcca44bbd320d4e73675b282fcee3d81400ea1b53934efca6462/torch-2.10.0-2-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:13ec4add8c3faaed8d13e0574f5cd4a323c11655546f91fbe6afa77b57423574", size = 79498202, upload-time = "2026-02-10T21:44:52.603Z" }, { url = "https://files.pythonhosted.org/packages/ec/23/2c9fe0c9c27f7f6cb865abcea8a4568f29f00acaeadfc6a37f6801f84cb4/torch-2.10.0-2-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:e521c9f030a3774ed770a9c011751fb47c4d12029a3d6522116e48431f2ff89e", size = 79498254, upload-time = "2026-02-10T21:44:44.095Z" }, + { url = "https://files.pythonhosted.org/packages/b3/7a/abada41517ce0011775f0f4eacc79659bc9bc6c361e6bfe6f7052a6b9363/torch-2.10.0-3-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:98c01b8bb5e3240426dcde1446eed6f40c778091c8544767ef1168fc663a05a6", size = 915622781, upload-time = "2026-03-11T14:17:11.354Z" }, + { url = "https://files.pythonhosted.org/packages/ab/c6/4dfe238342ffdcec5aef1c96c457548762d33c40b45a1ab7033bb26d2ff2/torch-2.10.0-3-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:80b1b5bfe38eb0e9f5ff09f206dcac0a87aadd084230d4a36eea5ec5232c115b", size = 915627275, upload-time = "2026-03-11T14:16:11.325Z" }, + { url = "https://files.pythonhosted.org/packages/d8/f0/72bf18847f58f877a6a8acf60614b14935e2f156d942483af1ffc081aea0/torch-2.10.0-3-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:46b3574d93a2a8134b3f5475cfb98e2eb46771794c57015f6ad1fb795ec25e49", size = 915523474, upload-time = "2026-03-11T14:17:44.422Z" }, { url = "https://files.pythonhosted.org/packages/cc/af/758e242e9102e9988969b5e621d41f36b8f258bb4a099109b7a4b4b50ea4/torch-2.10.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:5fd4117d89ffd47e3dcc71e71a22efac24828ad781c7e46aaaf56bf7f2796acf", size = 145996088, upload-time = "2026-01-21T16:24:44.171Z" }, { url = "https://files.pythonhosted.org/packages/23/8e/3c74db5e53bff7ed9e34c8123e6a8bfef718b2450c35eefab85bb4a7e270/torch-2.10.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:787124e7db3b379d4f1ed54dd12ae7c741c16a4d29b49c0226a89bea50923ffb", size = 915711952, upload-time = "2026-01-21T16:23:53.503Z" }, { url = "https://files.pythonhosted.org/packages/6e/01/624c4324ca01f66ae4c7cd1b74eb16fb52596dce66dbe51eff95ef9e7a4c/torch-2.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:2c66c61f44c5f903046cc696d088e21062644cbe541c7f1c4eaae88b2ad23547", size = 113757972, upload-time = "2026-01-21T16:24:39.516Z" },