Training recurrent neural networks via forward propagation through time
Citation
Anil Kag, Venkatesh Saligrama. Proceedings of the 38th International Conference on Machine Learning, PMLR 139:5189-5200, 2021.
Abstract
Back-propagation through time (BPTT) has been widely used for training Recurrent Neural Networks (RNNs). BPTT updates RNN parameters on an instance by back-propagating the error in time over the entire sequence length, and as a result, leads to poor trainability due to the well-known gradient explosion/decay phenomena. While a number of prior works have proposed to mitigate the vanishing/explosion effect through careful RNN architecture design, these RNN variants still train with BPTT. We propose a novel forward-propagation algorithm, FPTT, where at each time, for an instance, we update RNN parameters by optimizing an instantaneous risk function. Our proposed risk is a regularization penalty at time t that evolves dynamically based on previously observed losses, and allows for RNN parameter updates to converge to a stationary solution of the empirical RNN objective. We consider both sequence-to-sequence as well as terminal loss problems. Empirically, FPTT outperforms BPTT on a number of well-known benchmark tasks, thus enabling architectures like LSTMs to solve long-range dependency problems.
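
As a concrete illustration of the per-time-step update described in the abstract, the PyTorch sketch below updates the RNN parameters at each step against an instantaneous risk: the current step's loss plus a quadratic penalty toward a set of running parameters, with the hidden state carried forward detached so no back-propagation over the full sequence occurs. The penalty form, the running-parameter update rule, and all names (TinyRNN, fptt_style_step, w_bar, alpha) are illustrative assumptions, not the paper's exact algorithm.

import torch
import torch.nn as nn

class TinyRNN(nn.Module):
    def __init__(self, in_dim=4, hid_dim=32, out_dim=1):
        super().__init__()
        self.cell = nn.RNNCell(in_dim, hid_dim)
        self.readout = nn.Linear(hid_dim, out_dim)

def fptt_style_step(model, x_t, y_t, h, w_bar, optimizer, alpha=0.1):
    h = model.cell(x_t, h)                                  # advance the recurrence one step
    loss_t = nn.functional.mse_loss(model.readout(h), y_t)  # instantaneous loss at time t

    # Dynamic regularization: tie the current parameters to the running parameters w_bar.
    penalty = sum(((p - wb) ** 2).sum() for p, wb in zip(model.parameters(), w_bar))
    risk = loss_t + 0.5 * alpha * penalty

    optimizer.zero_grad()
    risk.backward()          # gradient flows only through the current time step
    optimizer.step()

    # Evolve the running parameters from the freshly updated weights
    # (simple averaging stands in here for the paper's loss-driven dynamics).
    with torch.no_grad():
        for p, wb in zip(model.parameters(), w_bar):
            wb.mul_(0.5).add_(p, alpha=0.5)

    return h.detach(), loss_t.item()   # detached state: no back-propagation through time

# Usage on one sequence (sequence-to-sequence setting, synthetic data).
model = TinyRNN()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
w_bar = [p.detach().clone() for p in model.parameters()]
x, y = torch.randn(50, 1, 4), torch.randn(50, 1, 1)   # (time, batch, features)
h = torch.zeros(1, 32)
for t in range(x.shape[0]):
    h, loss_t = fptt_style_step(model, x[t], y[t], h, w_bar, optimizer)

The point of this sketch is the structure of the update, not its exact form: gradients never flow across time steps, and it is the dynamically evolving penalty that couples successive per-step updates.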
License
Copyright © The authors and PMLR 2022