Mitigating Forgetting in LLM Supervised Fine-Tuning and Preference Learning
- Authors
Fernando, Heshan; Shen, Han; Ram, Parikshit; Zhou, Yi; Samulowitz, Horst; Baracaldo, Nathalie; and Chen, Tianyi
- Subjects
Computer Science - Machine Learning, Computer Science - Artificial Intelligence, Computer Science - Computation and Language, Mathematics - Optimization and Control, Statistics - Machine Learning
- Abstract
Post-training of pre-trained LLMs, which typically consists of a supervised fine-tuning (SFT) stage and a preference learning (RLHF or DPO) stage, is crucial for effective and safe LLM applications. The widely adopted approach to post-training popular open-source LLMs is to perform SFT and RLHF/DPO sequentially. However, sequential training is sub-optimal in terms of the trade-off between the SFT and RLHF/DPO objectives: the LLM gradually forgets what it learned in the first stage while undergoing the second stage's training. We theoretically prove the sub-optimality of sequential post-training. Furthermore, we propose a practical joint post-training framework that has theoretical convergence guarantees and empirically outperforms the sequential post-training framework, while having a similar computational cost. Our code is available at https://github.com/heshandevaka/XRIGHT.
- Published
- 2024