Calibrating Transformers Via Sparse Gaussian Processes
ABSTRACT
1 INTRODUCTION
Significant improvements in prediction accuracy have been achieved in computer vision, speech recognition and natural language processing using deep learning (He et al., 2015; Graves et al., 2013; Vaswani et al., 2017). In particular, Transformers (Vaswani et al., 2017) based on multi-head attention (MHA) have gained popularity in recent years. With Transformers being deployed in many downstream applications (Vaswani et al., 2017; Dosovitskiy et al., 2021; Brown et al., 2020), it is crucial to prevent the poor robustness that often stems from erratic, high-confidence outputs of these models (Guo et al., 2017b; Mukhoti et al., 2020). This requires calibrated uncertainty quantification for Transformers, which is much less well studied at the time of this work, and this gap raises concerns about using Transformers for safety-critical tasks that require rational and risk-averse decision making under uncertainty.
Regarding uncertainty quantification, Bayesian inference is a powerful and principled framework for building probabilistic models for rational prediction and decision-making under uncertainty (Gal, 2016). Significant progress has been made in applying (approximate) Bayesian inference methods to quantify uncertainty in fully-connected, convolutional and recurrent neural networks (Blundell et al., 2015; Gal & Ghahramani, 2016; Zhang et al., 2019; Ritter et al., 2021). Initial efforts have been made to extend these techniques to Transformers, but with mixed results (Tran et al., 2019; Xue et al., 2021). On the other hand, Gaussian processes (GPs) are gold-standard methods for tasks requiring reliable function-space uncertainty estimates (Rasmussen & Williams, 2006; Wilson et al., 2020). Researchers have proposed to integrate deep learning ideas into GP model design, including deep kernel learning (Wilson et al., 2016) and deep GPs (Damianou & Lawrence, 2013; Salimbeni & Deisenroth, 2017). Still, these models have yet to be scaled to modern deep learning tasks such as large-scale image classification and language modelling.
In this work, we propose sparse Gaussian process attention (SGPA), a novel uncertainty quantification technique for attention-based models (e.g., Transformers), which leverages techniques from sparse variational Gaussian processes (SVGP) (Snelson & Ghahramani, 2005; Hensman et al., 2013) for improved uncertainty estimates. Our work presents the following insights and contributions:
• Our key observation is that kernel-based attention (Tsai et al., 2019) is equivalent to the posterior mean of an SVGP. This inspires us to extend SVGP to Transformers for uncertainty estimation. The resulting Transformer based on our SGPA approach can be viewed as a sparse deep GP (Salimbeni & Deisenroth, 2017) with deep kernels in use for each GP layer.
Figure 1: Illustration of one head (h) of multi-head self-attention in one layer of (a) the vanilla Transformer, (b) the Transformer based on standard SGPA and (c) the Transformer based on decoupled SGPA.
• We address the computational inefficiency of a naive extension of SVGP to multi-head self-attention using a decoupled inducing points technique (Salimbeni et al., 2018), making SGPA scalable to the deep learning tasks that Transformers are applied to.
• Empirically, on a variety of vision, NLP and graph prediction tasks, SGPA-based Transformers improve considerably over the baselines in in-distribution calibration, out-of-distribution (OOD) robustness, and OOD detection, while achieving competitive accuracy against Transformers with standard (Vaswani et al., 2017) or kernel attention (Tsai et al., 2019).
2 BACKGROUND
The attention mechanism, first introduced in Graves et al. (2013), has become the core building block of Transformer models. In this work, we consider Transformers using multi-head self-attention (MHSA) as in Vaswani et al. (2017); Dosovitskiy et al. (2021). Here, we briefly review MHSA and sparse variational Gaussian processes, on which our method is built.
f is specified with a mean function (often set to zero) and a covariance function parameterized by a
kernel function Kψ (·, ·) (with hyperparameters ψ). Specifically, the marginal distribution of function
values f evaluated on any finite number of inputs X = [x1 , · · ·, xN ]⊤ , xn ∈ X is Gaussian:
Prior: p(f) ∼ GP(0, Kψ(·, ·)) ⇒ p(f|X) = N(0, KXX), [KXX]i,j = Kψ(xi, xj).    (3)
Given training data (X, y) with a Gaussian likelihood p(y|f) = N(f, σ²I), the posterior process is also a GP, and the posterior predictive distribution of f∗ evaluated at test inputs X∗ is:

p(f∗|X∗, X, y) = N(KX∗X(KXX + σ²I)⁻¹y, KX∗X∗ − KX∗X(KXX + σ²I)⁻¹KXX∗).    (4)
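As an illustration of eq.(4), below is a minimal NumPy sketch of exact GP regression with an RBF kernel; the kernel, its hyperparameters and the noise level are placeholders, not settings taken from this paper.

```python
import numpy as np

def rbf_kernel(A, B, lengthscale=1.0, variance=1.0):
    # [K]_{i,j} = variance * exp(-||a_i - b_j||^2 / (2 * lengthscale^2))
    d2 = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
    return variance * np.exp(-0.5 * d2 / lengthscale ** 2)

def gp_predict(X, y, X_star, noise_var=0.1):
    # Posterior predictive mean and covariance, following eq.(4).
    Kxx = rbf_kernel(X, X) + noise_var * np.eye(len(X))   # K_XX + sigma^2 I
    Ksx = rbf_kernel(X_star, X)                           # K_X*X
    Kss = rbf_kernel(X_star, X_star)                      # K_X*X*
    mean = Ksx @ np.linalg.solve(Kxx, y)
    cov = Kss - Ksx @ np.linalg.solve(Kxx, Ksx.T)
    return mean, cov

# Toy 1D regression example.
X = np.random.randn(50, 1)
y = np.sin(X[:, 0]) + 0.1 * np.random.randn(50)
X_star = np.linspace(-3.0, 3.0, 100)[:, None]
mu, Sigma = gp_predict(X, y, X_star)
```

The two linear solves against the N × N matrix KXX + σ²I are the source of the O(N³) cost discussed next.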
Unfortunately, with non-Gaussian likelihoods (e.g., for classification) or when the number of training datapoints N is large, the posterior process is intractable. Still, we can approximate the posterior process with a GP, and a popular approach is the sparse variational Gaussian process (SVGP) (Titsias, 2009; Hensman et al., 2013), which uses a small number M of inducing points (Z, u) = {(zm, um)}_{m=1}^{M} to summarise the training data and, to some degree, replaces the terms involving X, y in eq.(4) with the inducing points.
A detailed introduction to SVGP is provided in Appendix B.1. In short, it uses the properties of GPs to augment the prior as p(f, u|X, Z), a Gaussian with zero mean whose covariance matrix is the kernel matrix computed on [X, Z], and it defines the approximate posterior process as:

p(f∗, f, u|Z, X∗, X, y) ∝ p(y|f) p(f∗, f|u, Z, X∗, X) p(u|Z)
                        ≈ q(f∗, f, u|Z, X∗, X) := p(f∗, f|u, Z, X∗, X) q(u),   q(u) := N(mu, Su).    (5)
Notice that the exact posterior and the approximate posterior share the conditional distribution p(f∗, f|u, Z, X∗, X). This simplifies the evidence lower-bound (ELBO) objective for optimising the variational parameters mu, Su and the kernel hyperparameters ψ to

L_ELBO = E_q(f|X,Z)[log p(y|f)] − KL(q(u) || p(u|Z)).    (6)
Since q(u) and p(u|Z) are both Gaussian, the second term can be evaluated analytically. For non-Gaussian likelihoods, we resort to Monte-Carlo estimation for computing the first term. In prediction, the approximate posterior predictive distribution of f∗ evaluated on test inputs X∗ becomes:

q(f∗|X∗, Z) = ∫ p(f∗, f|u, Z, X∗, X) q(u) du df
            = N(KX∗Z KZZ⁻¹ mu, KX∗X∗ + KX∗Z KZZ⁻¹ (Su − KZZ) KZZ⁻¹ KZX∗).    (7)
Note that the computations of both the ELBO (eq.(6)) and the approximate posterior predictive distribution (eq.(7)) require the matrix inversion of KZZ only. Since we usually use a small number of inducing points (M ≪ N), the computational cost of SVGP, O(NM² + M³), is significantly lower than the O(N³) cost of the full GP resulting from the inversion of KXX (cf. eq.(4)).
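To make the cost comparison concrete, here is a minimal NumPy sketch of the SVGP predictive distribution in eq.(7). The variational parameters mu, Su and the kernel would normally be learned by maximising the ELBO in eq.(6); here they are arbitrary placeholders.

```python
import numpy as np

def rbf(A, B):
    d2 = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * d2)

def svgp_predict(X_star, Z, m_u, S_u, kernel=rbf, jitter=1e-6):
    # Approximate posterior predictive of eq.(7); only the M x M matrix K_ZZ is inverted.
    Kzz = kernel(Z, Z) + jitter * np.eye(len(Z))
    Ksz = kernel(X_star, Z)
    Kss = kernel(X_star, X_star)
    A = Ksz @ np.linalg.inv(Kzz)        # K_{X*Z} K_ZZ^{-1}; the inverse costs O(M^3)
    mean = A @ m_u
    cov = Kss + A @ (S_u - Kzz) @ A.T
    return mean, cov

M = 20
Z = np.random.randn(M, 1)                      # inducing inputs
m_u = np.random.randn(M)                       # variational mean of q(u)
L = 0.1 * np.tril(np.random.randn(M, M))       # Cholesky factor of S_u
S_u = L @ L.T + 1e-6 * np.eye(M)
X_star = np.linspace(-3.0, 3.0, 100)[:, None]
mu, Sigma = svgp_predict(X_star, Z, m_u, S_u)
```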
One way to take advantage of the expressiveness of DNNs in GPs is to parameterize the kernel function using a DNN, so that the network weights become part of the hyperparameters of a deep kernel (Wilson et al., 2016). Given a regular base kernel, such as an RBF kernel Kbase(·, ·), we can first map the inputs X to a feature space using a DNN, hθ(X), and then apply the base kernel to the DNN features corresponding to the inputs: Kdeep(·, ·) = Kbase(hθ(·), hθ(·)).
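A minimal sketch of this deep-kernel construction; the two-layer feature extractor hθ below is an arbitrary placeholder rather than an architecture used in this work.

```python
import numpy as np

rng = np.random.default_rng(0)

# Placeholder feature extractor h_theta: a random two-layer MLP.
W1, b1 = rng.normal(size=(1, 32)), np.zeros(32)
W2, b2 = rng.normal(size=(32, 8)), np.zeros(8)

def h_theta(X):
    return np.tanh(X @ W1 + b1) @ W2 + b2

def rbf(A, B, lengthscale=1.0):
    # Base kernel K_base, applied in feature space.
    d2 = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * d2 / lengthscale ** 2)

def deep_kernel(A, B):
    # K_deep(a, b) = K_base(h_theta(a), h_theta(b)); the weights of h_theta
    # act as additional kernel hyperparameters.
    return rbf(h_theta(A), h_theta(B))

X = rng.normal(size=(5, 1))
K = deep_kernel(X, X)   # 5 x 5 deep-kernel gram matrix
```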
where dk is the dimension of the keys. Since attention involves measuring the similarity between q and k, Tsai et al. (2019) generalised SDP-Attention by replacing softmax(qk⊤/√dk) in eq.(8) with a kernel gram matrix Kqk ([Kqk]i,j = K(qi, kj)) computed using a valid symmetric kernel K(·, ·), which we refer to as kernel attention or K-Attention for short:

K-Attention: F = Kqk v.    (9)
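A minimal sketch of a single K-Attention head as in eq.(9), using an (unnormalised) exponential kernel in the spirit of Tsai et al. (2019); the dimensions and the kernel's scaling are placeholders. Note that, unlike SDP attention, no softmax normalisation is applied.

```python
import numpy as np

def exponential_kernel(Q, K, scale):
    # [K_qk]_{i,j} = exp(q_i . k_j / scale)
    return np.exp(Q @ K.T / scale)

def k_attention(Q, K, V):
    # Eq.(9): F = K_qk v, applied to every output dimension of V at once.
    Kqk = exponential_kernel(Q, K, scale=np.sqrt(Q.shape[-1]))
    return Kqk @ V

T, d_k, d_v = 16, 32, 32                  # sequence length, key dim, value dim
Q = np.random.randn(T, d_k)
K = np.random.randn(T, d_k)
V = np.random.randn(T, d_v)
F = k_attention(Q, K, V)                  # (T, d_v) attention output
```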
Recall that the posterior mean of SVGP in eq.(7) is m = KXZ KZZ⁻¹ mu when evaluated on the training inputs (X∗ = X). Now we reparameterise the variational mean parameter of SVGP as [v]:,d := KZZ⁻¹ mu for each dimension d of v, and define the queries and keys as the input locations and inducing point locations: q := X, k := Z. By doing so, an equivalence can be identified between the posterior mean of an SVGP and each dimension of the output of a kernel attention block. This allows us to extend the toolbox of Gaussian processes and their scalable approximations to quantifying uncertainty in Transformers in the following sections.
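The claimed equivalence can be checked numerically in a few lines; the RBF kernel below is an arbitrary symmetric kernel used only for illustration.

```python
import numpy as np

def rbf(A, B):
    d2 = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * d2)

N, M, d = 10, 4, 1
X = np.random.randn(N, d)       # training inputs, playing the role of the queries q
Z = np.random.randn(M, d)       # inducing inputs, playing the role of the keys k
m_u = np.random.randn(M)        # variational mean of q(u)

Kzz_inv = np.linalg.inv(rbf(Z, Z) + 1e-9 * np.eye(M))
svgp_mean = rbf(X, Z) @ Kzz_inv @ m_u       # SVGP posterior mean K_XZ K_ZZ^{-1} m_u

v = (Kzz_inv @ m_u)[:, None]                # reparameterised value, [v]_{:,d} := K_ZZ^{-1} m_u
attention_out = rbf(X, Z) @ v               # K-Attention output F = K_qk v with q := X, k := Z

assert np.allclose(svgp_mean, attention_out[:, 0])
```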
Compared to standard SGPA (eq.(10), where k_a^h in decoupled SGPA is the same as k^h in standard SGPA), we see that the posterior mean of decoupled SGPA also involves two extra terms that take into account the effect of the global inducing points. More importantly, the posterior variances of the two SGPA methods differ only in the keys/inducing inputs in use (input-dependent keys k^h versus global keys k_g^h), and this brings in the key advantage of decoupled SGPA. As the posterior covariance in eq.(11) only involves the global inducing points, the variational covariance no longer needs to be input-dependent, and (the Cholesky factor of) S_g^h can be parameterised freely. Now the number of parameters for the covariance part is of order O(M_g²) (vs O(T²) in standard SGPA), and the computation of the matrix inversion pays a one-off cost of O(M_g³) (vs O(T³) for every input sequence). Notice that we are free to choose the number of global inducing points M_g, and in practice we find M_g = O(T_avg/H) is usually sufficient, where T_avg is the average length of the training input sequences. In Table 1, we summarise the time complexity (with batch size B) and the additional memory (number of parameters) required for SGPA in one head of a Transformer. We also include maximum likelihood estimation (MLE) for reference (note that the memory complexity of MLE does not depend on the input sequence length T).
As the time and memory savings are significant, we mainly evaluate decoupled SGPA in our experiments, and in the rest of the main text we will refer to decoupled SGPA as SGPA for short.

Table 1: Complexity comparison for standard and decoupled SGPA.

Model            Time                    Additional Memory
MLE              O(BT²)                  -
Standard SGPA    O(BT³)                  O(T²)
Decoupled SGPA   O(BT²M_g + M_g³)        O(M_g²)
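To illustrate where the savings come from, here is a rough sketch of the covariance computation with global inducing keys. Eq.(11) is not reproduced above, so the exact form below is an assumption that follows the SVGP covariance in eq.(7); the point is only that the M_g × M_g inversion is computed once and shared across all input sequences.

```python
import numpy as np

def rbf(A, B):
    d2 = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * d2)

M_g, d_k = 16, 32
k_g = np.random.randn(M_g, d_k)                   # global inducing keys, shared by all inputs
L_g = 0.1 * np.tril(np.random.randn(M_g, M_g))    # freely parameterised Cholesky factor of S_g
S_g = L_g @ L_g.T
Kgg = rbf(k_g, k_g) + 1e-6 * np.eye(M_g)
Kgg_inv = np.linalg.inv(Kgg)                      # one-off O(M_g^3) cost

def posterior_cov(q):
    # Assumed SVGP-style covariance with global inducing points (cf. eq.(7)):
    # only the T x M_g cross-kernel depends on the input sequence q.
    Kqg = rbf(q, k_g)
    A = Kqg @ Kgg_inv
    return rbf(q, q) + A @ (S_g - Kgg) @ A.T

for T in (8, 32):                                 # sequences of different lengths reuse Kgg_inv
    q = np.random.randn(T, d_k)
    cov = posterior_cov(q)                        # (T, T) posterior covariance
```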
− Σ_{l=1}^{L} Σ_{h=1}^{H} E_{q(F^l | F^0, {k_g^{j,h}}_{j=1,h=1}^{l,H})} [KL(q(u_{a∪g}^{l,h} | k_g^{l,h}, F^{l−1}) || p(u_{a∪g}^{l,h} | k_g^{l,h}, F^{l−1}))].    (14)
In practice, we resort to Monte-Carlo estimation of the ELBO, with samples of function values generated by iteratively passing through each layer using the reparameterization trick (eq.(13)).
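A minimal sketch of the reparameterization trick used for this layer-wise sampling. Eq.(13) is not reproduced above, so the fixed Gaussian below is only a stand-in for a per-layer SGPA posterior; in the actual model, each layer's mean and covariance are computed from the samples of the previous layer.

```python
import numpy as np

rng = np.random.default_rng(0)

def reparameterised_sample(mean, cov, rng, jitter=1e-6):
    # F = mean + L @ eps with eps ~ N(0, I) and L the Cholesky factor of cov,
    # so the sample is a differentiable function of (mean, cov).
    L = np.linalg.cholesky(cov + jitter * np.eye(len(cov)))
    return mean + L @ rng.standard_normal(mean.shape)

T = 8
mean = np.zeros(T)                                   # placeholder per-layer posterior mean
A = rng.standard_normal((T, T))
cov = A @ A.T / T                                    # placeholder per-layer posterior covariance
F_sample = reparameterised_sample(mean, cov, rng)    # one Monte-Carlo sample of the layer output
```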
4 EXPERIMENTS
We evaluate SGPA on prediction tasks across modalities, with the following experimental set-up.
• Datasets: CIFAR10 & CIFAR100 (image classification (Krizhevsky et al., 2009), CV tasks); CoLA (linguistic acceptability prediction (Warstadt et al., 2019), NLP task) and IMDB (sentiment analysis (Maas et al., 2011), NLP task).
• Network architectures: We use Vision Transformers (ViT (Dosovitskiy et al., 2021)) for CV tasks. For kernel attention we use the exponential kernel (Tsai et al., 2019) and the ARD-RBF kernel (Rasmussen & Williams, 2006) for NLP and CV tasks respectively. Scaled dot-product (SDP) attention based Transformers are also evaluated. As in Tsai et al. (2019), we find that kernel attention tends to outperform SDP attention in most tasks considered; thus we do not include the results of SDP attention in the main text. These results can be found in the tables in Appendix G.
• Baselines: We compare our approach with the following “single-model” methods: maximum likelihood estimation (MLE), Bayesian inference methods including mean-field variational inference (MFVI, (Blundell et al., 2015)) and Monte-Carlo Dropout (MCD, (Gal & Ghahramani, 2016)), Kronecker-factored last-layer Laplace approximation (KFLLLA) (Kristiadi et al., 2020), and Spectral-normalized Neural Gaussian Process (SNGP) (Liu et al., 2020). For tasks where a validation set is used, we also consider temperature scaling (TS) (Guo et al., 2017a) and use the validation set as the calibration set. For CV tasks, we also consider ensemble methods: we compare the SGPA ensemble (SGPAE) with a deep ensemble (DE) (Lakshminarayanan et al., 2017). We do not consider ensemble models for the NLP tasks since we use different train-(valid)-test splits across runs for them.
• Evaluations & metrics: We consider three evaluation set-ups: in-distribution performance, out-of-distribution (OOD) robustness and OOD detection. The metrics on the test set include predictive accuracy metrics for each task and uncertainty calibration metrics such as negative predictive log-likelihood (NLL), expected calibration error (ECE) and maximum calibration error (MCE) (Guo et al., 2017b); a reference sketch of ECE is given after this list. We report the mean ± two standard errors for each metric, obtained from 5 independent runs. For OOD detection tasks we consider the area under the ROC and precision-recall curves (AUROC and AUPR, respectively), and we report the average ranks in terms of AUROC and AUPR over all of the 6 OOD detection tasks for each method.
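For reference, a minimal sketch of ECE with equal-width confidence bins; 15 bins is a common default and not necessarily the setting used in these experiments.

```python
import numpy as np

def expected_calibration_error(probs, labels, n_bins=15):
    # probs: (N, C) predicted class probabilities; labels: (N,) integer class labels.
    conf = probs.max(axis=1)                  # confidence of the predicted class
    pred = probs.argmax(axis=1)
    correct = (pred == labels).astype(float)
    bin_edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(bin_edges[:-1], bin_edges[1:]):
        mask = (conf > lo) & (conf <= hi)
        if mask.any():
            # |average accuracy - average confidence| in the bin, weighted by bin frequency.
            ece += mask.mean() * abs(correct[mask].mean() - conf[mask].mean())
    return ece
```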
For fair comparisons, within each task, all the models are trained using the same architecture and
optimisation setting. All the models are trained from scratch without pre-training. We include the
experimental details in Appendix E. Results in tables are also presented in Appendix G.
We report the evaluation results on in-distribution test data for the image classification (CIFAR10 & CIFAR100, without data augmentation), sentiment analysis (IMDB), and linguistic acceptability (CoLA) tasks in the first, second, third and fourth rows of Figure 2 respectively. For the CoLA dataset, predictive accuracy is measured by the Matthews correlation coefficient (MCC) (Matthews, 1975) instead of accuracy, as in Warstadt et al. (2019).
All “single-model” calibration methods considered tend to improve calibration, except on sentiment analysis, where KFLLLA fails in the sense that it achieves worse calibration even than MLE (although KFLLLA achieves the best calibration for linguistic acceptability (CoLA), its performance is unstable across tasks). Although MFVI tends to achieve the lowest calibration errors, it severely underfits the data in all the experiments. This is undesirable, as improvement in calibration should not come at the price of a noticeable drop in predictive correctness. As a counter example, one can achieve