Training recurrent neural networks via forward propagation through time

Date
2021-06-07
Authors
Saligrama, Venkatesh
Kag, Anil
Citation
Anil Kag, Venkatesh Saligrama. "Training Recurrent Neural Networks via Forward Propagation Through Time." Proceedings of the 38th International Conference on Machine Learning, PMLR 139:5189-5200, 2021.
Abstract
Back-propagation through time (BPTT) has been widely used for training Recurrent Neural Networks (RNNs). BPTT updates RNN parameters on an instance by back-propagating the error in time over the entire sequence length, and as a result leads to poor trainability due to the well-known exploding/vanishing gradient phenomena. While a number of prior works have proposed to mitigate the vanishing/exploding gradient effect through careful RNN architecture design, these RNN variants still train with BPTT. We propose a novel forward-propagation algorithm, FPTT, in which at each time step, for an instance, we update RNN parameters by optimizing an instantaneous risk function. Our proposed risk is a regularization penalty at time t that evolves dynamically based on previously observed losses, and allows RNN parameter updates to converge to a stationary solution of the empirical RNN objective. We consider both sequence-to-sequence and terminal-loss problems. Empirically, FPTT outperforms BPTT on a number of well-known benchmark tasks, thus enabling architectures like LSTMs to solve long-range dependency problems.
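Below is a minimal PyTorch-style sketch of the general idea described in the abstract: instead of back-propagating through the whole sequence, parameters are updated at every time step by minimizing an instantaneous loss plus a dynamic regularization penalty that ties the weights to a running parameter average. The quadratic penalty, the simple averaging rule, and the hyperparameter `alpha` used here are illustrative assumptions, not the paper's exact formulas.

```python
# Illustrative per-time-step ("forward") training loop with a dynamic
# regularizer; the penalty form and averaging rule are assumptions.
import torch
import torch.nn as nn

torch.manual_seed(0)

cell = nn.LSTMCell(input_size=4, hidden_size=16)
head = nn.Linear(16, 1)
params = list(cell.parameters()) + list(head.parameters())
opt = torch.optim.SGD(params, lr=1e-2)

alpha = 1.0                                   # regularization strength (assumed)
w_bar = [p.detach().clone() for p in params]  # running parameter average

x = torch.randn(8, 50, 4)   # (batch, time, features) toy sequence
y = torch.randn(8, 50, 1)   # per-step targets (sequence-to-sequence setting)

h = torch.zeros(8, 16)
c = torch.zeros(8, 16)
for t in range(x.size(1)):
    # Detach the state so the computation graph covers only the current step.
    h, c = cell(x[:, t], (h.detach(), c.detach()))
    loss_t = nn.functional.mse_loss(head(h), y[:, t])

    # Instantaneous risk: step loss plus a penalty pulling the current
    # weights toward the running average (illustrative form).
    penalty = sum(((p - pb) ** 2).sum() for p, pb in zip(params, w_bar))
    risk_t = loss_t + 0.5 * alpha * penalty

    opt.zero_grad()
    risk_t.backward()
    opt.step()

    # Update the running average after the step (simple mean; assumed rule).
    with torch.no_grad():
        for pb, p in zip(w_bar, params):
            pb.mul_(0.5).add_(0.5 * p)
```

Because each update only involves the current step's loss, memory does not grow with sequence length; the evolving penalty is what couples the per-step updates to the overall sequence objective.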
License
Copyright © The authors and PMLR 2022