Vanishing Gradients in Reinforcement Finetuning of Language Models
Authors: Noam Razin, Hattie Zhou, Preetum Nakkiran, Josh Susskind, Omid Saremi, Arwen Bradley, Vimal Thilak, Etai Littwin
Pretrained language models are commonly adapted to comply with human intent and downstream tasks via finetuning. The finetuning process involves supervised finetuning (SFT), using labeled samples, and/or reinforcement learning based finetuning (RFT) via policy gradient methods, using a (possibly learned) reward function. This work highlights an overlooked optimization hurdle in RFT: we prove that the expected gradient for an input sample (i.e., prompt) vanishes if its reward standard deviation under the model is low, regardless of whether the reward mean is near-optimal. We then demonstrate the prevalence and detrimental effects of vanishing gradients due to low reward standard deviation in an RFT benchmark for language models. In particular, we show that in datasets where samples with low reward standard deviation under the pretrained model are more prevalent, RFT achieves lower reward relative to SFT. Controlled experiments and a theoretical analysis further establish that, even in simplified settings, vanishing gradients in RFT can lead to extremely slow convergence. Lastly, we explore ways to overcome vanishing gradients in RFT of language models. We find the common practice of an initial SFT phase to be the most promising candidate, which sheds light on its importance in an RFT pipeline. Furthermore, our experiments reveal that relatively few SFT optimization steps on a small number of labeled samples suffice, implying that the initial SFT phase need not be expensive in terms of compute and data labeling efforts.
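
The core claim, that the expected policy gradient for a prompt shrinks with the reward standard deviation under the model, can be illustrated numerically. The sketch below is not the paper's analysis: it assumes a toy softmax policy over a handful of discrete completions for a single prompt, and the helper name expected_policy_gradient is hypothetical. For a softmax policy, E_{y ~ pi}[r(y) * grad_theta log pi(y)] = pi * r - (pi . r) * pi, which is exactly zero when the reward is constant over the model's support and shrinks in proportion to the reward standard deviation under pi.

import numpy as np

# Exact expected REINFORCE gradient for a toy softmax policy over a small
# discrete set of completions (a stand-in for one prompt's output space):
#   grad = E_{y ~ pi}[ r(y) * grad_theta log pi(y) ] = pi * r - (pi . r) * pi
def expected_policy_gradient(theta, rewards):
    pi = np.exp(theta - theta.max())
    pi /= pi.sum()
    return pi * rewards - (pi @ rewards) * pi

rng = np.random.default_rng(0)
theta = rng.normal(size=8)            # toy "pretrained model" logits for one prompt
pi = np.exp(theta - theta.max()); pi /= pi.sum()
base = rng.normal(size=8)             # fixed reward fluctuations

for scale in [1.0, 0.1, 0.01]:
    rewards = 1.0 + scale * base      # same mean reward, shrinking spread
    mean = pi @ rewards
    std = np.sqrt(pi @ (rewards - mean) ** 2)
    grad = expected_policy_gradient(theta, rewards)
    print(f"reward std under model: {std:.4f}   ||expected gradient||: {np.linalg.norm(grad):.4f}")

Because the expected gradient is linear in the rewards and the constant part of the reward contributes nothing, scaling the reward fluctuations down by a factor of ten scales the exact expected gradient norm down by the same factor, mirroring the vanishing-gradient effect described in the abstract.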