Parallel PPO with discrete actions#
Like in the Distributed-SAC notebook, the agents perceive the other agent’s relative position and speed.
Instead of continuous accelerations, we train a policy that computes discrete accelerations using parallel PPO, since SAC from Stable-Baselines3 only supports continuous actions.
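As a quick illustration (not part of the training pipeline, and using only generic Gymnasium spaces), the difference comes down to the type of action space the algorithm accepts: SAC requires a continuous Box, while PPO also handles Discrete actions like the ones configured below.
[ ]:
import numpy as np
from gymnasium import spaces

# Continuous actions (what SAC expects): an acceleration sampled from an interval
continuous_actions = spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)
# Discrete actions (also supported by PPO): one of a finite set of acceleration levels
discrete_actions = spaces.Discrete(3)
continuous_actions.sample(), discrete_actions.sample()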
[1]:
from navground import core, sim
from navground.learning import DefaultObservationConfig
from navground.learning.config import DiscreteControlActionConfig
from navground.learning.parallel_env import make_vec_from_penv
from navground.learning.examples.pad import get_env, marker, neighbor
from stable_baselines3.common.vec_env import VecMonitor
name = "DistributedDiscrete"
action = DiscreteControlActionConfig(use_acceleration_action=True, max_acceleration=1, fix_orientation=True)
observation = DefaultObservationConfig(flat=False, include_velocity=True, include_target_direction=False)
sensors = [marker(), neighbor()]
train_env = get_env(action=action, observation=observation,
                    sensors=sensors, start_in_opposite_sides=True)
train_venv = VecMonitor(make_vec_from_penv(train_env, num_envs=4))
test_env = get_env(action=action, observation=observation,
                   sensors=sensors, start_in_opposite_sides=True)
test_venv = VecMonitor(make_vec_from_penv(test_env, num_envs=4))
The observation contains the pad position, the agent’s own velocity, and the neighbor’s relative position and velocity:
[2]:
train_venv.observation_space
[2]:
Dict('neighbor/position': Box(-10.0, 10.0, (1, 1), float32), 'neighbor/velocity': Box(-0.166, 0.166, (1, 1), float32), 'pad/x': Box(-1.0, 1.0, (1,), float32), 'ego_velocity': Box(-0.14, 0.14, (1,), float32))
[3]:
train_venv.action_space
[3]:
Discrete(3)
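For reference, here is a minimal sketch of how such a Discrete(3) action could be decoded into an acceleration, assuming a symmetric mapping of the three indices to {-max_acceleration, 0, +max_acceleration}; the actual decoding is handled internally by DiscreteControlActionConfig and may differ.
[ ]:
import numpy as np

# Hypothetical decoding, for illustration only: 0 -> -1, 1 -> 0, 2 -> +1 (in units of max_acceleration)
levels = np.linspace(-1.0, 1.0, 3)

def acceleration_from_action(index: int, max_acceleration: float = 1.0) -> float:
    return float(levels[index] * max_acceleration)

[acceleration_from_action(i) for i in range(3)]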
Training#
You can skip training and instead load the last trained policy by changing the flag below.
[4]:
from navground.learning.utils.jupyter import skip_if, run_if
training = True
[5]:
%%skip_if $training
import pathlib, os
from stable_baselines3 import PPO
log = max(pathlib.Path(f'logs/{name}/SAC').glob('*'), key=os.path.getmtime)
[6]:
%%run_if $training
from datetime import datetime as dt
from stable_baselines3 import PPO
from stable_baselines3.common.logger import configure
from navground.learning.utils.sb3 import callbacks
from navground.learning.scenarios.pad import render_kwargs
model = PPO("MultiInputPolicy", train_venv, verbose=0, n_steps=1024)
stamp = dt.now().strftime("%Y%m%d_%H%M%S")
log = f"logs/{name}/SAC/{stamp}"
model.set_logger(configure(log, ["csv", "tensorboard"]))
cbs = callbacks(venv=test_venv, best_model_save_path=log,
                eval_freq=1024 * 4, export_to_onnx=True,
                **render_kwargs())
log
[6]:
'logs/DistributedDiscrete/SAC/20250521_131506'
[7]:
%%run_if $training
model.learn(total_timesteps=1_000_000, reset_num_timesteps=False,
            log_interval=10, callback=cbs)
model.num_timesteps
[7]:
1007616
[8]:
from navground.learning.utils.sb3 import plot_eval_logs
plot_eval_logs(log, reward_low=-200, reward_high=0, success=True)

Evaluation#
[9]:
from stable_baselines3.common.evaluation import evaluate_policy
best_model = PPO.load(f'{log}/best_model')
evaluate_policy(best_model.policy, test_venv, n_eval_episodes=30)
[9]:
(-54.59525, 114.69993)
[10]:
from navground.learning.evaluation.video import display_episode_video
display_episode_video(test_env, policy=best_model.policy, factor=4, seed=1, **render_kwargs())
[10]:
[11]:
from navground.learning.evaluation.video import record_episode_video
record_episode_video(test_env, policy=best_model.policy,
                     path=f'../videos/{name}.mp4', seed=1, **render_kwargs())
[12]:
from navground.learning.utils.plot import plot_policy
plot_policy(best_model.policy,
            variable={'pad/x': (-1, 1), 'neighbor/position': (-2, 2)},
            fix={'ego_velocity': 0.1, 'neighbor/velocity': 0.1},
            actions={0: 'acceleration'}, width=5, height=3)
