Highlights of interesting behaviors generated by the SMART-tiny model fine-tuned with our Closest Among Top-K (CAT-K) rollout.
Traffic simulation aims to learn a policy for traffic agents that, when unrolled in closed loop, faithfully recovers the joint distribution of trajectories observed in the real world. Inspired by large language models, tokenized multi-agent policies have recently become the state of the art in traffic simulation. However, they are typically trained through open-loop behavior cloning, and thus suffer from covariate shift when executed in closed loop during simulation. In this work, we present Closest Among Top-K (CAT-K) rollouts, a simple yet effective closed-loop fine-tuning strategy that mitigates covariate shift. CAT-K fine-tuning requires only existing trajectory data, without reinforcement learning or generative adversarial imitation. Concretely, CAT-K fine-tuning enables a small 7M-parameter tokenized traffic simulation policy to outperform a 102M-parameter model from the same model family, achieving the top spot on the Waymo Sim Agent Challenge leaderboard at the time of submission. The code is available at this https URL.
Closest Among Top-K (CAT-K) rollouts unroll the policy during fine-tuning such that the visited states remain close to the ground truth (GT). At each time step, CAT-K first takes the top-K most likely action tokens according to the policy, then chooses the one leading to the state closest to the GT. As a result, CAT-K rollouts follow the mode of the GT (e.g., turning left), while random or top-K rollouts can deviate substantially (e.g., going straight or turning right). Since the policy is essentially trained to minimize the distance between the rollout states and the GT states, GT-based supervision remains effective for CAT-K rollouts, but not for random or top-K rollouts.
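To make the selection rule concrete, the snippet below is a minimal, illustrative PyTorch sketch of one CAT-K rollout step. The names policy, transition, and the Euclidean state distance are assumptions made for illustration, not the released API.

import torch

def cat_k_step(policy, transition, state, gt_next_state, k=5):
    # Token logits over the action vocabulary predicted by the policy.
    logits = policy(state)
    # Top-K most likely action tokens according to the policy.
    topk = torch.topk(logits, k)
    # State reached by applying each candidate token to the current state.
    candidate_states = torch.stack([transition(state, tok) for tok in topk.indices])
    # Distance of each candidate next state to the ground-truth (GT) next state.
    dists = torch.linalg.norm(candidate_states - gt_next_state, dim=-1)
    # Choose the candidate whose next state is closest to the GT.
    chosen = topk.indices[dists.argmin()]
    return chosen, transition(state, chosen), logits

During fine-tuning, the logits produced at each visited state can then be supervised with the tokenized GT action (e.g., via cross-entropy), which stays meaningful precisely because the CAT-K rollout remains close to the GT.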
Schematic comparison of CAT-K rollout, top-K sampling, and the data augmentation techniques of Trajeglish and SMART. In this example, the token vocabulary has a size of 5, and we roll out three steps from t = 0 to t = 3. For CAT-K rollout and top-K sampling, the top-K selection is w.r.t. the probabilities p of the tokens predicted by the policy. For the data augmentations used by Trajeglish and SMART, the policy is unavailable, and the top-K selection is based on the negative distances between the tokens and the GT.
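For concreteness, the sketch below contrasts the three token-selection rules from the schematic. The distance vector dist_to_gt (distance from each token, or its resulting state, to the GT) and the uniform sampling in the augmentation case are illustrative assumptions rather than the exact procedures of Trajeglish or SMART.

import torch

def select_cat_k(logits, dist_to_gt, k):
    # CAT-K rollout: top-K by policy probability, then the token closest to the GT.
    topk = torch.topk(logits, k)
    return topk.indices[dist_to_gt[topk.indices].argmin()]

def select_top_k_sampling(logits, k):
    # Top-K sampling: top-K by policy probability, then sample proportionally to probability.
    topk = torch.topk(logits, k)
    probs = torch.softmax(topk.values, dim=-1)
    return topk.indices[torch.multinomial(probs, 1).item()]

def select_distance_based_augmentation(dist_to_gt, k):
    # Data augmentation (no policy available): top-K by negative distance to the GT,
    # then pick one of the K nearest tokens (uniformly here, as an assumption).
    topk = torch.topk(-dist_to_gt, k)
    return topk.indices[torch.randint(k, (1,)).item()]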
@article{zhang2024closed,
  title={Closed-Loop Supervised Fine-Tuning of Tokenized Traffic Models},
  author={Zhang, Zhejun and Karkus, Peter and Igl, Maximilian and Ding, Wenhao and Chen, Yuxiao and Ivanovic, Boris and Pavone, Marco},
  journal={arXiv preprint arXiv:2412.05334},
  year={2024}
}