In this short post, I'll explain a practical problem you'll encounter when visualizing the gradients of decoder layers, and how to resolve it.
The Llama 3.2-1B model consists of a token input embedding layer, 16 stacked decoder layers, and a token prediction head.
Each decoder layer takes as input a hidden state tensor H of dimension (B, N, 2048), where B is the batch dimension, N is the number of tokens, and 2048 is the model dimension. Therefore H[0, 1, 10] represents the activation of the 11th "neuron" of the second token in a batch of size one.
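The snippets below assume a standard Hugging Face transformers setup, something like the following sketch (the exact checkpoint name and device handling here are assumptions, not part of the original code):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "meta-llama/Llama-3.2-1B"  # assumed checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token  # Llama has no pad token by default
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
model.eval()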
PyTorch's autograd allows us to compute the partial derivative of any activation (call it α) with respect to the hidden state of some earlier layer of the network - say, H_j for layer j:
tokenized = tokenizer(input_text, return_tensors="pt", padding=True).to(device)
result = model(**tokenized, output_hidden_states=True)
hidden_states = result.hidden_states

# alpha is a single scalar activation taken from a later layer,
# e.g. alpha = hidden_states[j + 1][:, token_idx, i]
gradient = torch.autograd.grad(
    outputs=alpha,
    inputs=hidden_states[j],
)[0]
torch.autograd.grad only computes the gradient of a single scalar activation at a time, so for the complete picture we need to loop over every activation in the next layer H_{j+1}:
gradients = {}
token_idx = ...

for i in range(2048):
    alpha = hidden_states[j + 1][:, token_idx, i]
    gradients[i] = torch.autograd.grad(
        outputs=alpha,
        inputs=hidden_states[j],
        retain_graph=True,  # keep the graph alive so grad() can be called repeatedly
    )[0]
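To actually draw the heatmap, one option is to stack the per-activation gradients for a single token position into a 2048x2048 matrix (a rough sketch, assuming matplotlib is available and the gradients dict from above):

import matplotlib.pyplot as plt

# Row i holds the gradient of activation i in layer j+1 with respect to layer j,
# restricted to the same token position.
jacobian = torch.stack(
    [gradients[i][0, token_idx, :] for i in range(2048)]
)  # shape: (2048, 2048)

plt.imshow(jacobian.detach().float().cpu().numpy(), cmap="viridis")
plt.xlabel("hidden_states[j] activations")
plt.ylabel("hidden_states[j+1] activations")
plt.colorbar()
plt.show()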
As a heatmap, the result looks something like this (this is called the Jacobian):
Notice the strong diagonal? This is not a bug. It's because of skip connections.
Skip Connections
Skip connections are ubiquitous in LLM architectures, as they are one of the primary reasons it is possible to train very deep networks. If we are fitting a layer y = x + H(x), where H(x) is the main function (i.e., self-attention), then H(x) = y - x can be interpreted as the residual. Thus each layer is responsible for learning an additive update to the input x instead of a wholesale transformation. This makes deep networks easier to train, as in some sense fitting the residual is "easier". It also appears to have a smoothing effect on the loss landscape, which is no doubt related. You can read the ResNet paper for more details.
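As a minimal illustration (a toy sketch, not the actual Llama code), a residual block just adds the main branch's output back onto its input:

import torch
from torch import nn

class ResidualBlock(nn.Module):
    def __init__(self, main_branch: nn.Module):
        super().__init__()
        self.main_branch = main_branch  # e.g. self-attention or an MLP

    def forward(self, x):
        # The block learns an additive update H(x); the skip connection adds x back.
        return x + self.main_branch(x)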
However, this is problematic for our purposes because we can't read the "real" gradients along the diagonal - they are obscured by the overwhelming effect of the skip connection.
The Fix
This is tricky to fix if we take the decoder layer as a whole. However, if we focus on the self-attention layer (or the fully connected layer) within the decoder block, it's very manageable.
Let's take a look at this (slightly simplified) implementation of the decoder layer.
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, self_attn_weights = self.self_attn(...)
hidden_states = residual + hidden_states # Skip connection
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states # Skip connection
return hidden_states
Notice that the implementation is cleanly separated into two "blocks": Self-Attention and Fully Connected. To make it a bit clearer, we might write it like this:
def forward(self, hidden_states):
    residual = hidden_states
    hidden_states = self.self_attention(hidden_states)
    hidden_states = residual + hidden_states  # <- Skip Connection

    residual = hidden_states
    hidden_states = self.fully_connected(hidden_states)
    hidden_states = residual + hidden_states  # <- Skip Connection

    return hidden_states
First, we'll need a few hook helper methods to store a reference to the relevant tensors.
import logging

layer_ins = {}
layer_outs = {}

def _save_layer_output(tag):
    def hook(_mod, args, kwargs, out):
        logging.debug(f"Saving output from {tag}, type: {type(out)}")
        if isinstance(out, tuple):
            out = out[0]
        layer_outs[tag] = out
    return hook

def _save_layer_input(tag):
    def hook(_mod, args, kwargs):
        logging.debug(f"Saving input to {tag}, type: {type(args)}")
        if isinstance(args, tuple):
            args = args[0]
        layer_ins[tag] = args
    return hook
Then, we can register the hooks. register_forward_pre_hook captures the input to an operation and register_forward_hook captures the output. If you refer back to the complete decoder layer forward pass implementation, input_layernorm and post_attention_layernorm are the cut points in the computation graph we need in order to isolate the self-attention portion of the decoder layer.
def setup_hooks_attn():
    # with_kwargs=True matches the (module, args, kwargs, ...) hook signatures above
    for idx, layer in enumerate(model.model.layers):
        layer.input_layernorm.register_forward_pre_hook(_save_layer_input(f"input_layernorm_{idx}"), with_kwargs=True)
        layer.post_attention_layernorm.register_forward_hook(_save_layer_output(f"post_attention_layernorm_{idx}"), with_kwargs=True)
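With the hooks registered, one way to build the Jacobian of a single layer's self-attention block is to differentiate the hooked output with respect to the hooked input. This is only a sketch reusing the assumed setup from earlier; layer_idx is an arbitrary example, and input_text and token_idx are the same assumed values as before:

setup_hooks_attn()
tokenized = tokenizer(input_text, return_tensors="pt", padding=True).to(device)
model(**tokenized)  # forward pass populates layer_ins / layer_outs

layer_idx = 5  # arbitrary example layer
attn_in = layer_ins[f"input_layernorm_{layer_idx}"]
attn_out = layer_outs[f"post_attention_layernorm_{layer_idx}"]

rows = []
for i in range(2048):
    grad = torch.autograd.grad(
        outputs=attn_out[:, token_idx, i],
        inputs=attn_in,
        retain_graph=True,
    )[0]
    rows.append(grad[0, token_idx, :])
jacobian = torch.stack(rows)  # (2048, 2048)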
The gradient visualization for self-attention still shows the same diagonal dominance phenomenon:
But it disappears if we simply subtract the identity matrix I (note the updated colormap scale).
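Continuing the sketch above, that correction is a one-liner:

# Remove the identity contribution of the skip connection
jacobian_main_branch = jacobian - torch.eye(2048, device=jacobian.device, dtype=jacobian.dtype)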
Why It Works
The explanation is simple. The self-attention block with skip connection computes y = x + H(x), so its Jacobian (read: partial derivative, or autograd gradient) with respect to x is ∂y/∂x = I + ∂H(x)/∂x. Therefore, the Jacobian of just the main branch is ∂H(x)/∂x = ∂y/∂x - I. In plain English, this simply means we should subtract the identity matrix from the gradient matrix.
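A quick sanity check on a toy function (a sketch using torch.autograd.functional.jacobian rather than the full model) confirms this:

import torch

main_branch = torch.nn.Linear(8, 8)            # stand-in for self-attention
residual_block = lambda x: x + main_branch(x)  # same layer with a skip connection

x = torch.randn(8)
J_full = torch.autograd.functional.jacobian(residual_block, x)
J_main = torch.autograd.functional.jacobian(main_branch, x)

# The two agree once the identity is removed
assert torch.allclose(J_full - torch.eye(8), J_main, atol=1e-6)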
This fix works equally well for the fully connected portion of the decoder layer.